In [1]:
import torch
from torchvision import datasets, transforms
import numpy as np
from matplotlib import pyplot as plt
from utils import plot_tsne
import numpy as np
import random
import argparse
from AutoEncoderDecoder import EncoderCIFAR, DecoderCIFAR, ClassifierCIFAR, ProjectionHead,SupConLoss, train_encoder_cifar, plot_reconstruction, plot_images_with_labels, trainEncoderMNIST123
from torch.utils.data import Dataset, DataLoader, random_split

In [2]:
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=0, type=int, help='Seed for random number generators')
    parser.add_argument('--data-path', default="~/datasets/cv_datasets/data", type=str, help='Path to dataset')
    parser.add_argument('--batch-size', default=8, type=int, help='Size of each batch')
    parser.add_argument('--latent-dim', default=128, type=int, help='Encoding dimension')
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help='Default device to use')
    parser.add_argument('--mnist', action='store_true', default=False, help='Use MNIST (True) or CIFAR10 (False) data')
    parser.add_argument('--self-supervised', action='store_true', default=False, help='Train self-supervised or jointly with classifier')
    parser.add_argument('--debug', action='store_true', default=False, help='Enable debugging for dataloader')
    args, unknown =  parser.parse_args()
    return args

NUM_CLASSES = 10

In [3]:
def freeze_seeds(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

In [4]:
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

In [5]:
freeze_seeds(0)

In [6]:
import os
current_dir = os.getcwd()
train_dataset = datasets.CIFAR10(root=current_dir, train=True, download=False, transform=transform)
class_names = train_dataset.classes 
test_dataset = datasets.CIFAR10(root=current_dir, train=False, download=False, transform=transform)

In [7]:
train_size = int(0.8 * len(train_dataset))
val_size   = len(train_dataset) - train_size

train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# Create DataLoader 
# dl_train = DataLoader(train_dataset, batch_size=16, shuffle=True)
# dl_test  = DataLoader(test_dataset, batch_size=16, shuffle=False)
# dl_val   = DataLoader(val_dataset, batch_size=16, shuffle=False)

dl_train = DataLoader(train_dataset, batch_size=128, shuffle=True)
dl_test  = DataLoader(test_dataset, batch_size=128, shuffle=False)
dl_val   = DataLoader(val_dataset, batch_size=128, shuffle=False)

im_size = train_dataset[0][0].shape

# Initialize the autoencoder and the optimizer
latent_dim = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = EncoderCIFAR(in_channels=im_size[0], latent_dim=128)
decoder= DecoderCIFAR(latent_dim=128, out_channels=im_size[0])
criterion = torch.nn.L1Loss()
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)


In [8]:
###### 1.2.1 Training the autoencoder - selfsupervised ######
num_epochs = 10
for epoch in range(num_epochs):
    #train
    epoch_loss = 0.0
    encoder.train()
    decoder.train()
    for data in dl_train:
        images, _ = data
        images = images.to(device)
        optimizer.zero_grad()
        
        latent = encoder(images)
        reconstructed = decoder(latent)
        loss = criterion(reconstructed, images)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    avg_epoch_loss = epoch_loss / len(dl_train) 
    print(f"Epoch {epoch+1}, Loss: {avg_epoch_loss}")

    #validation 
    encoder.eval()
    decoder.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in dl_val:
            images, _ = data
            images = images.to(device)
            latent = encoder(images)
            reconstructed = decoder(latent)
            loss = criterion(reconstructed, images)
            val_loss += loss.item()
    avg_val_loss = val_loss / len(dl_val)
    print(f"Ephoch {epoch+1}, Validation Loss: {avg_val_loss}")
    
    # Plot reconstructions test images
    with torch.no_grad():
        test_images, _ = next(iter(dl_test))
        test_images = test_images.to(device)
        latent = encoder(test_images)
        reconstructed = decoder(latent)
        
        reconstruction_loss = criterion(reconstructed, test_images)
        print(f"Test Reconstruction Loss: {reconstruction_loss.item():.4f}")
        plot_reconstruction(test_images, reconstructed)
        

KeyboardInterrupt: 

In [None]:
for param in encoder.parameters(): # Freeze encoder weights
    param.requires_grad = False

In [None]:
# Initialize the classifier and optimizer
classifier = ClassifierCIFAR(latent_dim=128, num_classes=10).to(device)
classifier_optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
criterion_classifier = torch.nn.CrossEntropyLoss()

In [None]:
#Training classifier
num_epochs = 10
for epoch in range(num_epochs):
    classifier.train() 
    epoch_loss = 0.0
    correct = 0
    total = 0

    for data in dl_train:
        images, labels = data
        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            latent = encoder(images)

        classifier_optimizer.zero_grad() 
        outputs = classifier(latent) 

        loss = criterion_classifier(outputs, labels)  # Cross-Entropy Loss
        loss.backward()  # Backpropagate
        classifier_optimizer.step()  # Update classifier weights

        epoch_loss += loss.item()  
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    avg_loss = epoch_loss / len(dl_train)
    accuracy = 100 * correct / total

    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

    # Validation 
    classifier.eval()  
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():  # Disable gradient calculation for validation
        for data in dl_val:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            latent = encoder(images)
            outputs = classifier(latent)
            loss = criterion_classifier(outputs, labels)

            val_loss += loss.item()  
            _, predicted = torch.max(outputs.data, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    avg_val_loss = val_loss / len(dl_val)
    val_accuracy = 100 * val_correct / val_total

    print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")


In [None]:
# Test the classifier on the test set
classifier.eval() 
test_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for data in dl_test:
        images, labels = data
        images, labels = images.to(device), labels.to(device)

        latent = encoder(images)
        outputs = classifier(latent)

        loss = criterion_classifier(outputs, labels)
        test_loss += loss.item()

        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Test Loss: {test_loss / len(dl_test):.4f}")
print(f"Test Accuracy: {accuracy:.2f}%")


# Visualize the images along with their true and predicted labels
plot_images_with_labels(images, labels, predicted, class_names, num_images=10)

encoder.eval()  # Set encoder to evaluation mode
plot_tsne(encoder, dl_test, device, 'CIFAR_1_2_1')

In [None]:
###1.2.2 - Classification-Guided Encoding###

encoder = EncoderCIFAR(in_channels=im_size[0], latent_dim=128) # reset encoder
classifier = ClassifierCIFAR(latent_dim=128, num_classes=NUM_CLASSES).to(device) # reset classifier

optimizer = torch.optim.Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=1e-3)
criterion_classifier = torch.nn.CrossEntropyLoss()


In [None]:
# Training loop 
num_epochs = 10
for epoch in range(num_epochs):
    encoder.train()  # training mode
    classifier.train()
    epoch_loss = 0.0
    correct = 0
    total = 0

    for data in dl_train:  
        images, labels = data
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        latent = encoder(images)  
        outputs = classifier(latent)

        loss = criterion_classifier(outputs, labels)
        loss.backward() 
        optimizer.step()  

        epoch_loss += loss.item()  
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    avg_epoch_loss = epoch_loss / len(dl_train)  
    accuracy = 100 * correct / total
    print(f"Epoch {epoch+1}, Loss: {avg_epoch_loss:.4f}, Accuracy: {accuracy:.2f}%")

# Validation phase
encoder.eval()  
classifier.eval()  
val_loss = 0.0
val_correct = 0
val_total = 0

with torch.no_grad():  # Disable gradient calculation for validation
    for data in dl_val:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
            
        latent = encoder(images)
        outputs = classifier(latent)
            
        loss_val = criterion_classifier(outputs, labels)
        val_loss += loss_val.item()
            
        _, predicted = torch.max(outputs.data, 1)
        val_total += labels.size(0)
        val_correct += (predicted == labels).sum().item()
    
avg_val_loss = val_loss / len(dl_val)
val_accuracy = 100 * val_correct / val_total
print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

In [None]:
# Evaluate the classifier on the test set
encoder.eval()  
classifier.eval()
test_loss = 0.0
correct = 0
total = 0

with torch.no_grad():
    for data in dl_test:
        images, labels = data
        images, labels = images.to(device), labels.to(device)

        latent = encoder(images)  
        outputs = classifier(latent)

        loss = criterion_classifier(outputs, labels)
        test_loss += loss.item()

        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = 100 * correct / total
print(f"Test Loss: {test_loss / len(dl_test):.4f}")
print(f"Test Accuracy: {test_accuracy:.2f}%")
# Visualize the images along with their true and predicted labels
plot_images_with_labels(images, labels, predicted, class_names, num_images=10)
plot_tsne(encoder, dl_test, device, 'CIFAR_1_2_2')

In [7]:
# ### 1.2.3 ###

transform = transforms.Compose([
    transforms.RandomResizedCrop(28, scale=(0.9, 1.0)),  # Crop but keep almost full size
    transforms.RandomHorizontalFlip(p=0.5),  # Only flipping
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

current_dir = os.getcwd()
train_dataset = datasets.CIFAR10(root=current_dir, train=True, download=False, transform=transform)
class_names = train_dataset.classes 
test_dataset = datasets.CIFAR10(root=current_dir, train=False, download=False, transform=transform)
train_size = int(0.8 * len(train_dataset))
val_size   = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
dl_train = DataLoader(train_dataset, batch_size=128, shuffle=True)
dl_test  = DataLoader(test_dataset, batch_size=128, shuffle=False)
dl_val   = DataLoader(val_dataset, batch_size=128, shuffle=False)

im_size = train_dataset[0][0].shape
latent_dim = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # # Train the encoder
encoder = EncoderCIFAR(in_channels=im_size[0], latent_dim=128) # reset encoder
projection_head = ProjectionHead(in_dim=128, out_dim=128).to(device)
# encoder123 = torch.nn.Sequential(encoder, projection_head)
train_encoder_cifar(encoder, projection_head, epochs=10, dl_train=dl_train, device=device)

Training 1.2.3 contrastive encoder for CIFAR with SupConLoss
Epoch 1/10, Loss: 1.912023
Epoch 2/10, Loss: 1.641444
Epoch 3/10, Loss: 1.497223
Epoch 4/10, Loss: 1.384064
Epoch 5/10, Loss: 1.293552
Epoch 6/10, Loss: 1.223186
Epoch 7/10, Loss: 1.152581
Epoch 8/10, Loss: 1.098858
Epoch 9/10, Loss: 1.040248
Epoch 10/10, Loss: 0.996470


In [10]:
classifier = ClassifierCIFAR(latent_dim=128, num_classes=10).to(device)
classifier_optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
criterion_classifier = torch.nn.CrossEntropyLoss()

# Training loop 
num_epochs = 10
for epoch in range(num_epochs):
    encoder.train()  # training mode
    classifier.train()
    epoch_loss = 0.0
    correct = 0
    total = 0

    for data in dl_train:  
        images, labels = data
        images, labels = images.to(device), labels.to(device)

        classifier_optimizer.zero_grad()
        latent = encoder(images)
        outputs = classifier(latent)

        loss = criterion_classifier(outputs, labels)
        loss.backward() 
        classifier_optimizer.step()  

        epoch_loss += loss.item()  
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    avg_epoch_loss = epoch_loss / len(dl_train)  
    accuracy = 100 * correct / total
    print(f"Epoch {epoch+1}, Loss: {avg_epoch_loss:.4f}, Accuracy: {accuracy:.2f}%")

    # Validation phase
    encoder.eval()  
    classifier.eval()  
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():  # Disable gradient calculation for validation
        for data in dl_val:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            latent = encoder(images)
            outputs = classifier(latent)

            loss_val = criterion_classifier(outputs, labels)
            val_loss += loss_val.item()

            _, predicted = torch.max(outputs.data, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()
    
    avg_val_loss = val_loss / len(dl_val)
    val_accuracy = 100 * val_correct / val_total
    print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

check
Epoch 1, Loss: 0.7933, Accuracy: 73.27%
Epoch 1, Validation Loss: 0.8350, Validation Accuracy: 71.25%
Epoch 2, Loss: 0.7313, Accuracy: 74.96%
Epoch 2, Validation Loss: 0.8302, Validation Accuracy: 71.80%
Epoch 3, Loss: 0.7204, Accuracy: 75.19%
Epoch 3, Validation Loss: 0.8114, Validation Accuracy: 72.14%
Epoch 4, Loss: 0.7158, Accuracy: 75.57%
Epoch 4, Validation Loss: 0.8186, Validation Accuracy: 72.47%
Epoch 5, Loss: 0.7032, Accuracy: 75.71%
Epoch 5, Validation Loss: 0.8312, Validation Accuracy: 71.80%
Epoch 6, Loss: 0.7002, Accuracy: 76.00%
Epoch 6, Validation Loss: 0.8081, Validation Accuracy: 72.83%
Epoch 7, Loss: 0.6995, Accuracy: 75.83%
Epoch 7, Validation Loss: 0.8182, Validation Accuracy: 72.48%
Epoch 8, Loss: 0.6975, Accuracy: 75.87%
Epoch 8, Validation Loss: 0.8146, Validation Accuracy: 71.97%
Epoch 9, Loss: 0.6960, Accuracy: 75.86%
Epoch 9, Validation Loss: 0.8114, Validation Accuracy: 72.49%
Epoch 10, Loss: 0.6924, Accuracy: 76.11%
Epoch 10, Validation Loss: 0.8124, 

In [11]:
# Evaluate the classifier on the test set
encoder.eval()  
classifier.eval()
test_loss = 0.0
correct = 0
total = 0

with torch.no_grad():
    for data in dl_test:
        images, labels = data
        images, labels = images.to(device), labels.to(device)

        latent = encoder(images)  
        outputs = classifier(latent)

        loss = criterion_classifier(outputs, labels)
        test_loss += loss.item()

        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = 100 * correct / total
print(f"Test Loss: {test_loss / len(dl_test):.4f}")
print(f"Test Accuracy: {test_accuracy:.2f}%")
# Visualize the images along with their true and predicted labels
plot_tsne(encoder, dl_test, device, 'CIFAR_1_2_3')

Test Loss: 0.8111
Test Accuracy: 71.77%
