In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
import torchvision


In [2]:
# Define the model architecture
class FeatureExtractor(nn.Module):
    """Frozen ResNet-50 backbone that turns an image batch into flat feature vectors.

    The network is ``models.resnet50(pretrained=True)`` with its last child
    module removed. It is used purely as a fixed feature extractor: the
    forward pass runs under ``torch.no_grad()`` and the backbone is kept in
    eval mode so BatchNorm layers use their stored running statistics instead
    of per-batch statistics (and do not silently update them).
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        self.model = models.resnet50(pretrained=True)
        self.model = nn.Sequential(*list(self.model.children())[:-1])  # Remove the last layer

        # The backbone is never optimised, so freeze its weights and switch
        # it to inference behaviour once, at construction time.
        for param in self.model.parameters():
            param.requires_grad_(False)
        self.model.eval()

    def train(self, mode=True):
        """Set the module's training flag but keep the backbone in eval mode.

        Without this override a parent ``.train()`` call would re-enable
        batch-statistics BatchNorm inside the frozen backbone.

        Returns:
            self, matching ``nn.Module.train``.
        """
        super(FeatureExtractor, self).train(mode)
        self.model.eval()
        return self

    def forward(self, x):
        """Return one flattened feature row per input sample.

        Args:
            x: image batch accepted by the ResNet-50 backbone.

        Returns:
            Tensor of shape ``(x.size(0), -1)`` detached from autograd.
        """
        with torch.no_grad():
            features = self.model(x)
        return features.view(features.size(0), -1)

class SimpleNN(nn.Module):
    """Three-layer convolutional network with a dropout-regularised linear head.

    The linear head expects ``256 * 8 * 8`` inputs, which the original
    implementation only produced for 64x64 feature maps. This version keeps
    that behaviour bit-for-bit and additionally accepts:

    * flat feature vectors of shape ``(N, input_channels)`` -- they are
      treated as ``(N, input_channels, 1, 1)`` maps, which is what the
      notebook feeds in from ``FeatureExtractor``;
    * feature maps of any other spatial size -- pooling is skipped once a
      map is too small to halve, and the result is adaptively average-pooled
      to the 8x8 grid the head was sized for.

    Parameter names and shapes are unchanged, so saved state dicts still load.
    """

    # Spatial grid the linear head was sized for (256 * 8 * 8 inputs).
    _HEAD_GRID = (8, 8)

    def __init__(self, input_channels, output_dim):
        super(SimpleNN, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(256 * 8 * 8, output_dim)

    @staticmethod
    def _pool(x):
        """Halve the spatial size with 2x2 max pooling when the map is large enough."""
        if min(x.shape[-2:]) < 2:
            # A 1-pixel-wide map cannot be pooled with a 2x2 window.
            return x
        return F.max_pool2d(x, 2)

    def forward(self, x):
        """Compute ``output_dim`` scores per sample.

        Args:
            x: tensor of shape ``(N, input_channels, H, W)`` or
                ``(N, input_channels)``.

        Returns:
            Tensor of shape ``(N, output_dim)``.
        """
        if x.dim() == 2:
            # Flat feature vectors: add singleton spatial dimensions.
            x = x.unsqueeze(-1).unsqueeze(-1)

        x = self._pool(torch.relu(self.conv1(x)))
        x = self._pool(torch.relu(self.conv2(x)))
        x = self._pool(torch.relu(self.conv3(x)))

        if tuple(x.shape[-2:]) != self._HEAD_GRID:
            # Bring any other spatial size to the grid the linear head expects.
            x = F.adaptive_avg_pool2d(x, self._HEAD_GRID)

        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = self.fc(x)
        return x

class BinaryClassifier(nn.Module):
    """Small MLP head: ``input_dim`` -> 512 -> 2 scores, with dropout in between."""

    def __init__(self, input_dim):
        super().__init__()
        hidden_units = 512
        self.fc1 = nn.Linear(input_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, 2)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return two scores per row of ``x``."""
        hidden = F.relu(self.fc1(x))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

class RealFakeDataset(Dataset):
    """Dataset that pairs each image path with the label at the same index."""

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths, self.labels, self.transform = image_paths, labels, transform

    def __len__(self):
        """Number of samples, i.e. the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an RGB image and return ``(image, label)``.

        The optional ``transform`` is applied to the image when it is set.
        """
        sample = Image.open(self.image_paths[idx]).convert('RGB')
        target = self.labels[idx]

        if not self.transform:
            return sample, target

        return self.transform(sample), target

class TeacherStudentNetworks:
    """Plain container holding a teacher and a student ``SimpleNN`` built with identical arguments."""

    def __init__(self, feature_dim, output_dim):
        def build_network():
            # Each call creates an independently initialised network.
            return SimpleNN(feature_dim, output_dim)

        # The teacher is constructed first, then the student.
        self.teacher = build_network()
        self.student = build_network()
        
class FeatureAugmenter(nn.Module):
    """Two-layer MLP mapping a feature vector to a new vector of the same width.

    Layout: ``input_dim`` -> 512 (ReLU) -> ``input_dim``.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_width = 512
        self.fc1 = nn.Linear(input_dim, hidden_width)
        self.fc2 = nn.Linear(hidden_width, input_dim)

    def forward(self, x):
        """Return the transformed features; the output has the same last dimension as ``x``."""
        return self.fc2(F.relu(self.fc1(x)))

In [3]:
def _collect_label_dir(label_path, label_value, image_paths, labels):
    """Append every regular file inside ``label_path`` together with ``label_value``.

    Entries are visited in sorted order so results are reproducible across
    platforms, and sub-directories are skipped so they are never handed to
    the image loader as if they were images.

    Returns:
        True when ``label_path`` is an existing directory, False otherwise
        (in which case nothing is appended).
    """
    if not os.path.isdir(label_path):
        return False
    for img_name in sorted(os.listdir(label_path)):
        img_path = os.path.join(label_path, img_name)
        if os.path.isfile(img_path):
            image_paths.append(img_path)
            labels.append(label_value)
    return True


def load_image_paths_and_labels(root_dir):
    """Collect image paths and binary labels from a real/fake directory tree.

    Two layouts are supported below each sub-directory of ``root_dir``:

    * ``<root>/<model>/0_real|1_fake/<image>``
    * ``<root>/<model>/<object>/0_real|1_fake/<image>``

    The nested layout is only searched for a label whose folder is not found
    directly under ``<model>``.

    Args:
        root_dir: directory whose sub-directories are scanned.

    Returns:
        ``(image_paths, labels)`` -- two parallel lists; the label is 0 for
        files found in a ``0_real`` folder and 1 for files in ``1_fake``.
        Directory listings are sorted, so the order is deterministic.
    """
    image_paths = []
    labels = []

    for model in sorted(os.listdir(root_dir)):
        model_path = os.path.join(root_dir, model)
        if not os.path.isdir(model_path):
            continue

        for label_value, label in enumerate(('0_real', '1_fake')):
            direct_path = os.path.join(model_path, label)
            if _collect_label_dir(direct_path, label_value, image_paths, labels):
                continue

            # Label folder not found directly under the model folder:
            # look one level deeper, inside each per-object folder.
            for obj in sorted(os.listdir(model_path)):
                nested_path = os.path.join(model_path, obj, label)
                _collect_label_dir(nested_path, label_value, image_paths, labels)

    return image_paths, labels

# Optimized Training Functions
def train_teacher(teacher, feature_extractor, dataloader, device, num_epochs=10):
    """Train ``teacher`` with cross-entropy on features from ``feature_extractor``.

    Args:
        teacher: network to optimise; receives the extracted features.
        feature_extractor: module mapping an image batch to features. It is
            not optimised here.
        dataloader: iterable of ``(images, labels)`` batches that supports ``len()``.
        device: device the modules and batches are moved to.
        num_epochs: number of passes over ``dataloader``.

    Side effects:
        Moves both modules to ``device``, puts ``feature_extractor`` in eval
        mode and ``teacher`` in train mode, and prints the mean loss per epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(teacher.parameters(), lr=0.001)
    # CUDA mixed precision is only meaningful on a CUDA device; on CPU the
    # scaler and autocast are disabled instead of emitting warnings.
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    feature_extractor.to(device)
    teacher.to(device)
    # The extractor is a fixed front end: keep BatchNorm/Dropout in inference mode.
    feature_extractor.eval()

    # Guard the per-epoch average against an empty dataloader.
    num_batches = max(len(dataloader), 1)

    for epoch in range(num_epochs):
        teacher.train()
        running_loss = 0.0
        for images, labels in tqdm(dataloader, desc=f'Teacher training {epoch+1}/{num_epochs}', unit='batch'):
            images, labels = images.to(device), labels.to(device)
            with torch.no_grad():
                features = feature_extractor(images)

            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = teacher(features)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

        print(f'Epoch {epoch+1}, Loss: {running_loss/num_batches}')

    print('Finished Training Teacher')

def train_student(teacher, student, feature_extractor, augmenter, dataloader, device, num_epochs=10, margin=1.0):
    """Train ``student`` against a fixed ``teacher``.

    For samples labelled 0 (real) the student is pulled towards the teacher
    output with a mean-squared-error loss. For samples labelled 1 (fake) the
    features are first passed through ``augmenter``; the student is then
    pushed away from the teacher with the hinge loss
    ``max(margin - discrepancy, 0)``, where ``discrepancy`` is the mean
    squared difference of the L2-normalised outputs.

    Only ``student`` parameters are optimised. The teacher, the augmenter and
    the feature extractor are treated as fixed functions, so their outputs are
    computed without building an autograd graph.

    Args:
        teacher: reference network; not updated.
        student: network to optimise.
        feature_extractor: module mapping an image batch to features.
        augmenter: module applied to the features of fake samples.
        dataloader: iterable of ``(images, labels)`` batches that supports ``len()``.
        device: device the modules and batches are moved to.
        num_epochs: number of passes over ``dataloader``.
        margin: hinge margin for the fake-sample loss.

    Side effects:
        Moves all modules to ``device``, puts ``teacher`` and
        ``feature_extractor`` in eval mode and ``student`` in train mode, and
        prints the mean real/fake losses per epoch.
    """
    optimizer = optim.Adam(student.parameters(), lr=0.001)
    # CUDA mixed precision is only enabled when actually running on CUDA.
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    teacher.to(device)
    student.to(device)
    augmenter.to(device)
    feature_extractor.to(device)

    # The teacher provides the targets: dropout must be off so the targets
    # are deterministic. The extractor is likewise a fixed front end.
    teacher.eval()
    feature_extractor.eval()

    # Guard the per-epoch averages against an empty dataloader.
    num_batches = max(len(dataloader), 1)

    for epoch in range(num_epochs):
        student.train()
        running_loss_real = 0.0
        running_loss_fake = 0.0

        for images, labels in tqdm(dataloader, desc=f'Student training {epoch+1}/{num_epochs}', unit='batch'):
            images, labels = images.to(device), labels.to(device)
            with torch.no_grad():
                features = feature_extractor(images)

            # Real images
            real_indices = (labels == 0)
            real_features = features[real_indices]
            if len(real_features) > 0:
                with torch.cuda.amp.autocast(enabled=use_amp):
                    with torch.no_grad():
                        teacher_real = teacher(real_features)
                    student_real = student(real_features)
                    loss_real = torch.mean((teacher_real - student_real) ** 2)

                optimizer.zero_grad()
                scaler.scale(loss_real).backward()
                scaler.step(optimizer)
                scaler.update()

                running_loss_real += loss_real.item()

            # Fake images
            fake_indices = (labels == 1)
            fake_features = features[fake_indices]
            if len(fake_features) > 0:
                with torch.no_grad():
                    augmented_fake_features = augmenter(fake_features)
                with torch.cuda.amp.autocast(enabled=use_amp):
                    with torch.no_grad():
                        teacher_fake = teacher(augmented_fake_features)
                    student_fake = student(augmented_fake_features)
                    # F.normalize clamps the norm away from zero, so an
                    # all-zero output row cannot turn the loss into NaN.
                    teacher_unit = F.normalize(teacher_fake, p=2, dim=1)
                    student_unit = F.normalize(student_fake, p=2, dim=1)
                    discrepancy = torch.mean((teacher_unit - student_unit) ** 2)
                    loss_fake = torch.clamp(margin - discrepancy, min=0.0)

                optimizer.zero_grad()
                scaler.scale(loss_fake).backward()
                scaler.step(optimizer)
                scaler.update()

                running_loss_fake += loss_fake.item()

        print(f'Epoch {epoch+1}, Real Loss: {running_loss_real/num_batches}, Fake Loss: {running_loss_fake/num_batches}')

    print('Finished Training Student')

def train_binary_classifier(teacher, student, classifier, feature_extractor, dataloader, device, num_epochs=10):
    """Train ``classifier`` on the squared teacher/student output discrepancy.

    For every batch the element-wise squared difference between the teacher
    and student outputs is computed without gradients and fed to
    ``classifier``, which is optimised with cross-entropy against the labels.

    Args:
        teacher: fixed network; not updated.
        student: fixed network; not updated.
        classifier: network to optimise; receives the discrepancies.
        feature_extractor: module mapping an image batch to features.
        dataloader: iterable of ``(images, labels)`` batches that supports ``len()``.
        device: device the modules and batches are moved to.
        num_epochs: number of passes over ``dataloader``.

    Side effects:
        Moves all modules to ``device``, puts ``teacher``, ``student`` and
        ``feature_extractor`` in eval mode and ``classifier`` in train mode,
        and prints the mean loss per epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(classifier.parameters(), lr=0.001)
    # CUDA mixed precision is only enabled when actually running on CUDA.
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    teacher.to(device)
    student.to(device)
    classifier.to(device)
    feature_extractor.to(device)

    # The discrepancy is the classifier's input feature. Dropout inside the
    # teacher/student must be off here, otherwise the classifier is trained
    # on noisy inputs that differ from the ones it sees at evaluation time.
    teacher.eval()
    student.eval()
    feature_extractor.eval()

    # Guard the per-epoch average against an empty dataloader.
    num_batches = max(len(dataloader), 1)

    for epoch in range(num_epochs):
        classifier.train()
        running_loss = 0.0

        for images, labels in tqdm(dataloader, desc=f'Classifier training {epoch+1}/{num_epochs}', unit='batch'):
            images, labels = images.to(device), labels.to(device)

            with torch.no_grad():
                features = feature_extractor(images)
                teacher_outputs = teacher(features)
                student_outputs = student(features)
            discrepancies = (teacher_outputs - student_outputs) ** 2

            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = classifier(discrepancies)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

        print(f'Epoch {epoch+1}, Loss: {running_loss/num_batches}')

    print('Finished Training Binary Classifier')

In [4]:
# Data preparation
# Channel statistics of ImageNet: the ImageNet-pretrained ResNet-50 used as the
# feature extractor expects its inputs to be normalised with these values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


In [5]:

# Load image paths and labels
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
TRAIN_SET_DIR = r'C:\Users\Danila\VSU\vsu_common_rep\vsu_common_rep\2year\2term\project\image_classification\content\CNN_synth\train_set'
image_paths, labels = load_image_paths_and_labels(TRAIN_SET_DIR)

# Create a dataset instance
full_dataset = RealFakeDataset(image_paths, labels, transform=transform)

# Define the split ratio
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

# Split the dataset with a fixed seed so the train/test partition is reproducible
SPLIT_SEED = 42
train_dataset, test_dataset = torch.utils.data.random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create data loaders (loading happens in the main process).
# Page-locked memory only helps host-to-GPU copies, so it is enabled on CUDA only.
use_pinned_memory = device.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=use_pinned_memory)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, pin_memory=use_pinned_memory)

# Initialize networks
feature_extractor = FeatureExtractor()
feature_dim = 2048  # ResNet50 output feature dimension
teacher_student = TeacherStudentNetworks(feature_dim, 2)
augmenter = FeatureAugmenter(feature_dim)
binary_classifier = BinaryClassifier(2)

# Checkpoints are written to 'models/'. Create the folder up front so a long
# training run cannot fail at the very end because the directory is missing.
os.makedirs('models', exist_ok=True)

# Train the teacher network
train_teacher(teacher_student.teacher, feature_extractor, train_loader, device)
torch.save(teacher_student.teacher.state_dict(), 'models/teacher_10ep.pth')

# Train the student network
train_student(teacher_student.teacher, teacher_student.student, feature_extractor, augmenter, train_loader, device)
torch.save(teacher_student.student.state_dict(), 'models/student_10ep.pth')

# Train the binary classifier
train_binary_classifier(teacher_student.teacher, teacher_student.student, binary_classifier, feature_extractor, train_loader, device)
torch.save(binary_classifier.state_dict(), 'models/classifier_10ep.pth')

# Function to evaluate the model on test data
def evaluate_model(test_loader):
    """Evaluate the full pipeline on ``test_loader`` and report its accuracy.

    Uses the notebook-level ``feature_extractor``, ``teacher_student``,
    ``binary_classifier`` and ``device``. Every network is switched to eval
    mode first so Dropout/BatchNorm behave deterministically; predictions are
    the arg-max of the classifier output computed from the squared
    teacher/student discrepancy.

    Args:
        test_loader: iterable of ``(images, labels)`` batches.

    Returns:
        The accuracy as computed by ``accuracy_score`` (also printed as a
        percentage).
    """
    feature_extractor.eval()
    teacher_student.teacher.eval()
    teacher_student.student.eval()
    binary_classifier.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            features = feature_extractor(images)
            teacher_outputs = teacher_student.teacher(features)
            student_outputs = teacher_student.student(features)
            discrepancies = (teacher_outputs - student_outputs) ** 2
            outputs = binary_classifier(discrepancies)
            _, predicted = torch.max(outputs, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    print(f'Accuracy: {accuracy * 100:.2f}%')
    return accuracy

# Evaluate the trained pipeline on the held-out test split (prints the accuracy)
evaluate_model(test_loader)

Teacher training 1/10:   2%|█▎                                                    | 56/2259 [00:24<16:07,  2.28batch/s]


KeyboardInterrupt: 