In [72]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
import torchvision
from time import gmtime, strftime

In [101]:
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        self.model = models.resnet50(pretrained=True)
        self.model = nn.Sequential(*list(self.model.children())[:-1])  # Remove the last layer

    def forward(self, x):
        with torch.no_grad():
            features = self.model(x)
        return features.view(features.size(0), -1)
        
class SimpleNN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

class TeacherStudentNetworks:
    def __init__(self, feature_dim, output_dim):
        self.teacher = SimpleNN(feature_dim, output_dim)
        self.student = SimpleNN(feature_dim, output_dim)
        
class FeatureAugmenter:
    def __init__(self, input_dim):
        self.feature_augmenter = SimpleNN(input_dim, input_dim)
        
class BinaryClassifier:
    def __init__(self, input_dim):
        self.binary_classifier = SimpleNN(input_dim, 2)

class RealFakeDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label



In [115]:
def load_image_paths_and_labels(root_dir):
    image_paths = []
    labels = []
    
    for model in os.listdir(root_dir):
        model_path = os.path.join(root_dir, model)
        if os.path.isdir(model_path):
            for label in ['0_real', '1_fake']:
                label_path = os.path.join(model_path, label)
                if os.path.isdir(label_path):
                    for img_name in os.listdir(label_path):
                        img_path = os.path.join(label_path, img_name)
                        image_paths.append(img_path)
                        labels.append(0 if '0_real' in label else 1)
                else:
                    for obj in os.listdir(model_path):
                        obj_path = os.path.join(model_path, obj)
                        label_path = os.path.join(obj_path, label)
                        if os.path.isdir(label_path):
                            for img_name in os.listdir(label_path):
                                img_path = os.path.join(label_path, img_name)
                                image_paths.append(img_path)
                                labels.append(0 if '0_real' in label else 1)
    
    return image_paths, labels
    
def train_teacher(teacher, feature_extractor, dataloader, device, num_epochs=10):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(teacher.parameters(), lr=0.01, weight_decay=0.001)
    feature_extractor.to(device)
    teacher.to(device)
    teacher.train()
    feature_extractor.eval()
    
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in tqdm(dataloader, desc=f'Teacher training {epoch+1}/{num_epochs}', unit='batch'):
            images, labels = images.to(device), labels.to(device)
            features = feature_extractor(images)
            
            optimizer.zero_grad()
            outputs = teacher(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        print(f'Epoch {epoch+1}, Loss: {running_loss/len(dataloader)}')
        save_checkpoint(teacher, filename='teacher_cpt', epoch=epoch)
        
    print('Finished Training Teacher')
    

def train_studen_augmenter(teacher, student, feature_extractor, augmenter, dataloader, device, num_epochs=10, margin=1.0):
    optimizer = optim.Adam(student.parameters(), lr=0.01, weight_decay=0.001)
    optimizer_augmenter = optim.Adam(augmenter.parameters(), lr=0.01, weight_decay=0.001)
    teacher.to(device)
    student.to(device)
    augmenter.to(device)
    feature_extractor.to(device)
    student.train()
    teacher.eval()
    feature_extractor.eval()
    augmenter.train()
    
    for epoch in range(num_epochs):

        running_loss_real = 0.0
        running_loss_fake = 0.0
        running_loss_augmenter = 0.0
        
        for images, labels in tqdm(dataloader, desc=f'Student and Augmenter training {epoch+1}/{num_epochs}', unit='batch'):
            images, labels = images.to(device), labels.to(device)
            features = feature_extractor(images)
            
            # Real images
            real_indices = (labels == 0)
            real_features = features[real_indices]
            if len(real_features) > 0:
                teacher_real = teacher(real_features)
                student_real = student(real_features)
                loss_real = torch.mean((teacher_real - student_real) ** 2)
                
                optimizer.zero_grad()
                loss_real.backward()
                optimizer.step()
                
                running_loss_real += loss_real.item()
            
            # Fake images
            fake_indices = (labels == 1)
            fake_features = features[fake_indices]
            if len(fake_features) > 0:
                augmented_fake_features = augmenter(fake_features)
                teacher_fake = teacher(augmented_fake_features)
                student_fake = student(augmented_fake_features)
                discrepancy = torch.mean((teacher_fake / teacher_fake.norm(2, dim=1, keepdim=True) - student_fake / student_fake.norm(2, dim=1, keepdim=True)) ** 2)
                loss_fake = torch.clamp(margin - discrepancy, min=0.0)
                
                optimizer.zero_grad()
                loss_fake.backward()
                optimizer.step()
                
                loss_augmenter = torch.mean((teacher_fake - student_fake) ** 2)
                
                optimizer_augmenter.zero_grad()
                loss_augmenter.backward()
                optimizer_augmenter.step()
                
                
                running_loss_fake += loss_fake.item()
                running_loss_augmenter+= loss_augmenter.item()
                
        print(f'Epoch {epoch+1}, Real Loss: {running_loss_real/len(dataloader)}, Fake Loss: {running_loss_fake/len(dataloader)}')
        save_checkpoint(student, filename='student_cpt', epoch=epoch)
        save_checkpoint(augmenter, filename='augmenter_cpt', epoch=epoch)
        
    print('Finished Training Student and Augmenter')
        
def train_student(teacher, student, feature_extractor, augmenter, dataloader, device, num_epochs=10, margin=1.0):
    optimizer = optim.Adam(student.parameters(), lr=0.01, weight_decay=0.001)
    teacher.to(device)
    student.to(device)
    augmenter.to(device)
    feature_extractor.to(device)
    student.train()
    teacher.eval()
    feature_extractor.eval()
    augmenter.eval()
    
    for epoch in range(num_epochs):

        running_loss_real = 0.0
        running_loss_fake = 0.0
        
        for images, labels in tqdm(dataloader, desc=f'Student training {epoch+1}/{num_epochs}', unit='batch'):
            images, labels = images.to(device), labels.to(device)
            features = feature_extractor(images)
            
            # Real images
            real_indices = (labels == 0)
            real_features = features[real_indices]
            if len(real_features) > 0:
                teacher_real = teacher(real_features)
                student_real = student(real_features)
                loss_real = torch.mean((teacher_real - student_real) ** 2)
                
                optimizer.zero_grad()
                loss_real.backward()
                optimizer.step()
                
                running_loss_real += loss_real.item()
            
            # Fake images
            fake_indices = (labels == 1)
            fake_features = features[fake_indices]
            if len(fake_features) > 0:
                augmented_fake_features = augmenter(fake_features)
                teacher_fake = teacher(augmented_fake_features)
                student_fake = student(augmented_fake_features)
                discrepancy = torch.mean((teacher_fake / teacher_fake.norm(2, dim=1, keepdim=True) - student_fake / student_fake.norm(2, dim=1, keepdim=True)) ** 2)
                loss_fake = torch.clamp(margin - discrepancy, min=0.0)
                
                optimizer.zero_grad()
                loss_fake.backward()
                optimizer.step()
                
                running_loss_fake += loss_fake.item()
        
        print(f'Epoch {epoch+1}, Real Loss: {running_loss_real/len(dataloader)}, Fake Loss: {running_loss_fake/len(dataloader)}')
        save_checkpoint(student, filename='student_cpt', epoch=epoch)
        
    print('Finished Training Student')
    
def train_binary_classifier(teacher, student, classifier, feature_extractor, dataloader, device, num_epochs=10):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(classifier.parameters(), lr=0.01, weight_decay=0.001)
    teacher.to(device)
    student.to(device)
    classifier.to(device)
    feature_extractor.to(device)
    classifier.train()
    teacher.eval()
    student.eval()
    feature_extractor.eval()
    augmenter.eval()
    
    for epoch in range(num_epochs):
        running_loss = 0.0
        
        for images, labels in tqdm(dataloader, desc=f'Classifier training {epoch+1}/{num_epochs}', unit='batch'):
            images, labels = images.to(device), labels.to(device)
            features = feature_extractor(images)
            
            with torch.no_grad():
                teacher_outputs = teacher(features)
                student_outputs = student(features)
            discrepancies = (teacher_outputs - student_outputs) ** 2
            outputs = classifier(discrepancies)
            
            optimizer.zero_grad()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        print(f'Epoch {epoch+1}, Loss: {running_loss/len(dataloader)}')
        save_checkpoint(augmenter, filename='bin_classifier_cpt', epoch=epoch)
        
    print('Finished Training Binary Classifier')


def train_augmenter(teacher, student, feature_extractor, augmenter, dataloader, device, num_epochs=10):
    optimizer = optim.Adam(augmenter.parameters(), lr=0.001, weight_decay=0.0001)
    teacher.to(device)
    student.to(device)
    augmenter.to(device)
    feature_extractor.to(device)
    teacher.eval()
    student.eval()
    feature_extractor.eval()
    augmenter.train()
    
    for epoch in range(num_epochs):
        running_loss = 0.0
        
        for images, labels in tqdm(dataloader, desc=f'Augmenter training {epoch+1}/{num_epochs}', unit='batch'):
            images, labels = images.to(device), labels.to(device)
            features = feature_extractor(images)
            
            # Fake images
            fake_indices = (labels == 1)
            fake_features = features[fake_indices]
            if len(fake_features) > 0:
                augmented_fake_features = augmenter(fake_features)
                with torch.no_grad:
                    teacher_fake = teacher(augmented_fake_features)
                    student_fake = student(augmented_fake_features)
                loss = torch.mean((teacher_fake - student_fake) ** 2)
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                running_loss+= loss.item()
                
        print(f'Epoch {epoch+1}, Loss: {running_loss/len(dataloader)}')
        save_checkpoint(augmenter, filename='augmenter_cpt', epoch=epoch)
        
    print('Finished Training Augmenter')

In [4]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
])


In [80]:
def save_checkpoint(model, filename='checkpoint', epoch = 0):
    modelfilename = filename + f'_epoch{epoch}.pth.tar'
    torch.save(model.state_dict(), f'models/{modelfilename}')
    
    print(f'Saved checkpoint as: {modelfilename}')
def load_checkpoint(model, path):
    model.load_state_dict(torch.load(path))
    print(f'Loaded model: {path}')
    
def add_log(msg):
     with open('log.txt', 'a') as file:
        time = strftime("%H:%M", gmtime())
        file.write(time + ': ' + msg + '\n')

In [109]:
# Load image paths and labels
image_paths, labels = load_image_paths_and_labels('C:/Users/Danila/VSU/vsu_common_rep/vsu_common_rep/2year/2term/project/image_classification/content/CNN_synth/train_set/')

# Create a dataset instance 
full_dataset = RealFakeDataset(image_paths, labels, transform=transform)

# Define the split ratio
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

# Split the dataset
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [102]:
# Initialize networks
feature_extractor = FeatureExtractor()
feature_dim = 2048  # ResNet50 output feature dimension
teacher_student = TeacherStudentNetworks(feature_dim, 2)
augmenter = FeatureAugmenter(feature_dim)
binary_classifier = BinaryClassifier(2)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_student.teacher.to(device)
teacher_student.student.to(device)
binary_classifier.binary_classifier.to(device)
feature_extractor.to(device)
augmenter.feature_augmenter.to(device)

In [None]:
# Train the teacher network
train_teacher(teacher_student.teacher, feature_extractor, train_loader, device, 10)

In [116]:
# Train the student and the augmenter
train_studen_augmenter(teacher_student.teacher, teacher_student.student, feature_extractor, augmenter.feature_augmenter, train_loader, device, 10)

# Train the binary classifier
train_binary_classifier(teacher_student.teacher, teacher_student.student, binary_classifier.binary_classifier, feature_extractor, train_loader, device, 10)

testacc = test_binary_classifier(teacher_student.teacher, teacher_student.student, binary_classifier.binary_classifier, feature_extractor, test_loader, device)
mytestacc = test_binary_classifier(teacher_student.teacher, teacher_student.student, binary_classifier.binary_classifier, feature_extractor, mytest_loader, device)
add_log(f'Test dataset accuracy: {testacc}')
add_log(f'My test dataset accuracy: {mytestacc}')

Student and Augmenter training 1/10:   0%|                                                 | 0/2259 [00:00<?, ?batch/s]


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512, 2]], which is output 0 of AsStridedBackward0, is at version 10; expected version 9 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Teacher training 1/10: 100%|████████████████████████████████████████████████████| 2259/2259 [14:55<00:00,  2.52batch/s]


Epoch 1, Loss: 0.7866641403352755
Saved checkpoint as: teacher_cpt_epoch0.pth.tar


Teacher training 2/10: 100%|████████████████████████████████████████████████████| 2259/2259 [15:18<00:00,  2.46batch/s]


Epoch 2, Loss: 0.7466870907029519
Saved checkpoint as: teacher_cpt_epoch1.pth.tar


Teacher training 3/10: 100%|████████████████████████████████████████████████████| 2259/2259 [15:05<00:00,  2.50batch/s]


Epoch 3, Loss: 0.7773583210533095
Saved checkpoint as: teacher_cpt_epoch2.pth.tar


Teacher training 4/10: 100%|████████████████████████████████████████████████████| 2259/2259 [14:05<00:00,  2.67batch/s]


Epoch 4, Loss: 0.764106681622357
Saved checkpoint as: teacher_cpt_epoch3.pth.tar


Teacher training 5/10: 100%|████████████████████████████████████████████████████| 2259/2259 [13:55<00:00,  2.71batch/s]


Epoch 5, Loss: 0.7977485331406157
Saved checkpoint as: teacher_cpt_epoch4.pth.tar


Teacher training 6/10: 100%|████████████████████████████████████████████████████| 2259/2259 [13:57<00:00,  2.70batch/s]


Epoch 6, Loss: 0.7842030296973709
Saved checkpoint as: teacher_cpt_epoch5.pth.tar


Teacher training 7/10: 100%|████████████████████████████████████████████████████| 2259/2259 [13:52<00:00,  2.71batch/s]


Epoch 7, Loss: 0.7754422654985273
Saved checkpoint as: teacher_cpt_epoch6.pth.tar


Teacher training 8/10: 100%|████████████████████████████████████████████████████| 2259/2259 [13:53<00:00,  2.71batch/s]


Epoch 8, Loss: 0.7679152669184711
Saved checkpoint as: teacher_cpt_epoch7.pth.tar


Teacher training 9/10: 100%|████████████████████████████████████████████████████| 2259/2259 [13:54<00:00,  2.71batch/s]


Epoch 9, Loss: 0.7852224255621354
Saved checkpoint as: teacher_cpt_epoch8.pth.tar


Teacher training 10/10: 100%|███████████████████████████████████████████████████| 2259/2259 [13:55<00:00,  2.70batch/s]


Epoch 10, Loss: 0.7877704605506342
Saved checkpoint as: teacher_cpt_epoch9.pth.tar
Finished Training Teacher


Student and Augmenter training 1/10:   0%|                                                 | 0/2259 [00:01<?, ?batch/s]


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [None]:

# for i in range(10):
#     # Train the student network
#     train_student(teacher_student.teacher, teacher_student.student, feature_extractor, augmenter, train_loader, device, 1, margin= 1.0)
#     save_checkpoint(teacher_student.teacher, filename='student_cpt', epoch=i)
    
#     # Traun the augmenter
#     train_augmenter(teacher_student.teacher, teacher_student.student, feature_extractor, augmenter, train_loader, device, 1)
#     save_checkpoint(teacher_student.teacher, filename='augmenter_cpt', epoch=i)
    

In [25]:
def test_binary_classifier(teacher, student, classifier, feature_extractor, dataloader, device):
    classifier.to(device)
    teacher.to(device) 
    student.to(device) 
    feature_extractor.to(device) 
    classifier.eval() # Set the classifier to evaluation mode
    teacher.eval()
    student.eval()
    feature_extractor.eval()

    all_labels = []
    all_preds = []
    
    with torch.no_grad():  # Disable gradient calculation for testing
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            
            # Extract features
            features = feature_extractor(images)
            
            # Get outputs from teacher and student networks
            teacher_outputs = teacher(features)
            student_outputs = student(features)
            
            # Calculate discrepancy
            discrepancy = (teacher_outputs - student_outputs) ** 2
            
            # Get predictions from binary classifier
            outputs = classifier(discrepancy)
            _, preds = torch.max(outputs, 1)
            
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
    
    accuracy = accuracy_score(all_labels, all_preds)
    print(f'Test Accuracy: {accuracy * 100:.2f}%')
    return accuracy

In [24]:
test_binary_classifier(teacher_student.teacher, teacher_student.student, binary_classifier, feature_extractor, test_loader, device)

TypeError: 'type' object does not support the context manager protocol

In [10]:
mytest_image_paths, mytest_labels = load_image_paths_and_labels("C:/Users/Danila/VSU/vsu_common_rep/vsu_common_rep/2year/2term/project/image_classification/content/CNN_synth/test_set/")
mytest_dataset = RealFakeDataset(mytest_image_paths, mytest_labels, transform=transform)
mytest_loader = DataLoader(mytest_dataset, batch_size=32, shuffle=False)

In [11]:
test_binary_classifier(teacher_student.teacher, teacher_student.student, binary_classifier, feature_extractor, mytest_loader, device)

Test Accuracy: 49.07%


0.4906609195402299

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize networks
feature_extractor = FeatureExtractor()
feature_dim = 2048  # ResNet50 output feature dimension
teacher_student = TeacherStudentNetworks(feature_dim, 512)
augmenter = FeatureAugmenter(feature_dim)
binary_classifier = BinaryClassifier(2048)


load_checkpoint(teacher_student.student, 'models_before_augmenter_training/student_cpt_epoch6.pth.tar')
load_checkpoint(teacher_student.teacher, 'models_before_augmenter_training/teacher_cpt_epoch6.pth.tar')

load_checkpoint(binary_classifier, 'models_before_augmenter_training/bin_classifier_cpt_epoch6.pth.tar')

In [105]:
save_checkpoint(teacher_student.teacher, filename='teacher_test_cpt', epoch=0)
save_checkpoint(teacher_student.student, filename='student_test_cpt', epoch=0)
save_checkpoint(augmenter.feature_augmenter, filename='augmenter_test_cpt', epoch=0)
save_checkpoint(binary_classifier.binary_classifier, filename='bin_classifier_test_cpt', epoch=0)

Saved checkpoint as: teacher_test_cpt_epoch0.pth.tar
Saved checkpoint as: student_test_cpt_epoch0.pth.tar
Saved checkpoint as: augmenter_test_cpt_epoch0.pth.tar
Saved checkpoint as: bin_classifier_test_cpt_epoch0.pth.tar


In [106]:
load_checkpoint(teacher_student.student, 'models/teacher_test_cpt_epoch0.pth.tar')
load_checkpoint(teacher_student.teacher, 'models/student_test_cpt_epoch0.pth.tar')
load_checkpoint(augmenter.feature_augmenter, 'models/augmenter_test_cpt_epoch0.pth.tar')
load_checkpoint(binary_classifier.binary_classifier, 'models/bin_classifier_test_cpt_epoch0.pth.tar')

Loaded model: models/teacher_test_cpt_epoch0.pth.tar
Loaded model: models/student_test_cpt_epoch0.pth.tar
Loaded model: models/augmenter_test_cpt_epoch0.pth.tar
Loaded model: models/bin_classifier_test_cpt_epoch0.pth.tar


In [35]:
train_augmenter(teacher_student.teacher, teacher_student.student, feature_extractor, augmenter, train_loader, device, 1)

Student training 1/1:   1%|▋                                                      | 26/2259 [00:10<14:41,  2.53batch/s]


KeyboardInterrupt: 