In [1]:
from tqdm.auto import tqdm
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data_utils
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import pathlib

In [2]:
def show_images_and_labels(device, model, test_loader, class_names, n_per_class=3, mean=None, std=None):
    """Plot up to ``n_per_class`` example predictions for every class.

    The figure has one row per entry in ``class_names`` and ``n_per_class``
    columns. Each panel shows an image from ``test_loader`` titled with the
    predicted and the original class name. Images are taken in loader order,
    and iteration stops early once every class row is full.

    Args:
        device: device the model runs on; image batches are moved there.
        model: classifier whose output column ``i`` scores ``class_names[i]``.
        test_loader: yields ``(images, labels)`` batches. Presumably the images
            are normalized CHW float tensors (TODO confirm against the loader).
        class_names: class names indexed by integer label.
        n_per_class: number of examples (columns) shown per class; default 3.
        mean, std: optional per-channel statistics that were used to normalize
            the images. When both are given, the normalization is undone before
            display. Pixel values are clipped to [0, 1] either way, which is what
            ``imshow`` would otherwise do (with a warning) for float RGB input.
    """
    num_classes = len(class_names)
    # squeeze=False keeps `axes` 2-D even with a single class or column.
    fig, axes = plt.subplots(num_classes, n_per_class,
                             figsize=(5 * n_per_class, 3 * num_classes), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')  # panels for classes with too few examples stay blank

    # Hoist the de-normalization tensors out of the loop (HWC layout broadcasts on channels).
    denorm = None
    if mean is not None and std is not None:
        denorm = (torch.as_tensor(mean, dtype=torch.float32),
                  torch.as_tensor(std, dtype=torch.float32))

    shown = [0] * num_classes  # examples already drawn per class label
    model.eval()
    with torch.no_grad():  # inference only, no gradient tracking
        for images, labels in test_loader:
            _, predicted = torch.max(model(images.to(device)), 1)
            for image, label, pred in zip(images, labels.tolist(), predicted.tolist()):
                col = shown[label]
                if col >= n_per_class:
                    continue
                img = image.permute(1, 2, 0).cpu()  # CHW -> HWC for imshow
                if denorm is not None:
                    img = img * denorm[1] + denorm[0]
                ax = axes[label, col]
                ax.imshow(img.clamp(0, 1).numpy())
                ax.set_title(f"Predicted: {class_names[pred]}\nOriginal: {class_names[label]}")
                shown[label] += 1

            if all(count >= n_per_class for count in shown):
                break

    # Prevent overlap
    plt.tight_layout()
    plt.show()

In [13]:
def data_generation(dataset_path, num_classes=10, data_augmentation=False, batch_size=32,
                    val_ratio=0.2, seed=None):
    """Build train/validation/test DataLoaders for an ImageFolder-style dataset.

    ``dataset_path`` must contain ``train`` and ``val`` sub-directories laid out
    for ``torchvision.datasets.ImageFolder`` (one folder per class).

    - The ``train`` folder is split per class into a training part and a
      validation part (``val_ratio`` of every class goes to validation).
    - The ``val`` folder is used as the test set.

    Args:
        dataset_path: dataset root (str or Path); a trailing slash is optional.
        num_classes: labels ``0..num_classes-1`` are stratified for the split.
            Samples with any other label always stay in the training part.
        data_augmentation: if True, the training loader also yields an augmented
            copy (random flip + rotation) of every *training* image. Validation
            images are never part of the augmented copy.
        batch_size: batch size for all three loaders.
        val_ratio: fraction of each class held out for validation (default 0.2).
        seed: optional seed for the train/validation split. ``None`` draws from
            the global ``random`` module state.

    Returns:
        ``(train_loader, val_loader, test_loader, class_names)``, where
        ``class_names[i]`` is the class folder name for label ``i``.
    """
    # Mean and standard deviation values calculated from function get_mean_and_std
    # (helper not shown in this notebook section).
    mean = [0.4708, 0.4596, 0.3891]
    std = [0.1951, 0.1892, 0.1859]
    normalize = transforms.Normalize(torch.Tensor(mean), torch.Tensor(std))

    augment_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.ToTensor(),
        normalize,
    ])
    # Training and test images receive the same deterministic preprocessing.
    plain_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        normalize,
    ])

    root = pathlib.Path(dataset_path)
    train_dir = str(root / "train")
    train_dataset = datasets.ImageFolder(root=train_dir, transform=plain_transform)
    test_dataset = datasets.ImageFolder(root=str(root / "val"), transform=plain_transform)

    # Stratified split of the train folder: sample val_ratio of each class.
    rng = random if seed is None else random.Random(seed)
    indices_by_class = {c: [] for c in range(num_classes)}
    for idx, label in enumerate(train_dataset.targets):
        if label in indices_by_class:
            indices_by_class[label].append(idx)

    val_indices = []
    for class_indices in indices_by_class.values():
        num_val = int(len(class_indices) * val_ratio)
        val_indices.extend(rng.sample(class_indices, num_val))

    # Set membership keeps this O(n) instead of O(n^2) list lookups.
    val_index_set = set(val_indices)
    train_indices = [i for i in range(len(train_dataset)) if i not in val_index_set]

    train_data = torch.utils.data.Subset(train_dataset, train_indices)
    val_data = torch.utils.data.Subset(train_dataset, val_indices)

    if data_augmentation:
        augmented_dataset = datasets.ImageFolder(root=train_dir, transform=augment_transform)
        # Augment only the training indices. Augmenting the full folder would
        # leak validation images into training.
        augmented_train = torch.utils.data.Subset(augmented_dataset, train_indices)
        train_data = torch.utils.data.ConcatDataset([train_data, augmented_train])

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # ImageFolder's own class list is exactly what its integer labels index into.
    class_names = list(train_dataset.classes)

    return train_loader, val_loader, test_loader, class_names

In [14]:
def _evaluate(device, model, loader, criterion, desc="eval"):
    """Return ``(mean loss, accuracy)`` of ``model`` on ``loader`` without tracking gradients.

    Both values are ``float('nan')`` when the loader yields no samples, instead
    of raising ZeroDivisionError.
    """
    model.eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc=desc, leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # criterion returns a batch mean; re-weight by batch size to get a dataset mean.
            total_loss += criterion(outputs, labels).item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)
    if total_samples == 0:
        return float("nan"), float("nan")
    return total_loss / total_samples, total_correct / total_samples


def trainCNN(device, train_loader, val_loader, test_loader, model, num_epochs=10, optimizer="Adam", lr=0.001):
    """Train ``model`` with cross-entropy, reporting train/val metrics each epoch.

    After the final epoch, the model is evaluated once on ``test_loader``.

    Args:
        device: device the model already lives on; batches are moved there.
        train_loader, val_loader, test_loader: iterables of ``(inputs, labels)``.
        model: classifier producing one logit per class.
        num_epochs: number of passes over ``train_loader``.
        optimizer: optimizer name; only ``"Adam"`` is supported.
        lr: learning rate (default 0.001).

    Returns:
        A dict of per-epoch lists ``loss``, ``accuracy``, ``val_loss`` and
        ``val_accuracy``, plus ``test_loss``/``test_accuracy`` when
        ``num_epochs > 0``.

    Raises:
        ValueError: if ``optimizer`` is not a supported name.
    """
    criterion = nn.CrossEntropyLoss()
    if optimizer != "Adam":
        raise ValueError(f"Unsupported optimizer {optimizer!r}; only 'Adam' is supported")
    # Hand only trainable parameters to the optimizer (frozen backbone layers are skipped).
    opt_func = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)

    history = {"loss": [], "accuracy": [], "val_loss": [], "val_accuracy": []}
    for epoch in tqdm(range(num_epochs)):
        model.train()  # Set the model to training mode
        running_loss, total_correct, total_samples = 0.0, 0, 0
        for inputs, labels in tqdm(train_loader, leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            opt_func.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            opt_func.step()

            _, predicted = torch.max(outputs, 1)
            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)
            running_loss += loss.item() * inputs.size(0)

        train_loss = running_loss / total_samples if total_samples else float("nan")
        accuracy = total_correct / total_samples if total_samples else float("nan")
        print(f"Epoch [{epoch+1}/{num_epochs}], Accuracy: {accuracy * 100:.2f}%, Loss: {train_loss:.4f}")

        val_loss, val_accuracy = _evaluate(device, model, val_loader, criterion, desc="val")
        print(f"Epoch [{epoch+1}/{num_epochs}], Validation Accuracy: {val_accuracy * 100:.2f}%, Validation Loss: {val_loss:.4f}")

        history["loss"].append(train_loss)
        history["accuracy"].append(accuracy)
        history["val_loss"].append(val_loss)
        history["val_accuracy"].append(val_accuracy)

    # The test set is touched exactly once, after training has finished.
    if num_epochs > 0:
        test_loss, test_accuracy = _evaluate(device, model, test_loader, criterion, desc="test")
        print(f"Test Accuracy: {test_accuracy * 100:.2f}%, Test Loss: {test_loss:.4f}")
        history["test_loss"] = test_loss
        history["test_accuracy"] = test_accuracy

    return history

In [15]:
def feature_extraction(model, device):
    """Freeze every parameter of ``model`` (pure feature-extraction mode).

    Layers attached to the model after this call are trainable again, because
    new parameters default to ``requires_grad=True``. ``device`` is unused and is
    accepted for signature parity with the other freezing helpers.
    """
    for tensor in model.parameters():
        tensor.requires_grad_(False)

def freeze_till_k(model, device, k):
    """Freeze the parameters of the first ``k`` direct children of ``model``.

    Children are visited in ``model.children()`` order. Those with index
    ``< k`` are frozen; the rest are left untouched (so ``k <= 0`` freezes
    nothing). ``device`` is unused and is kept for signature parity with the
    other freezing helpers.
    """
    # A single pass over the children is sufficient; there is no need to repeat
    # it once per parameter tensor.
    for idx, child in enumerate(model.children()):
        if idx >= k:
            break  # remaining children stay trainable
        for param in child.parameters():
            param.requires_grad = False

def no_freezing(model, device, k):
    """Make every parameter of ``model`` trainable (full fine-tuning).

    ``device`` and ``k`` are unused; they keep the signature parallel to
    ``freeze_till_k`` so callers can switch strategies uniformly.
    """
    # Loop variable and attribute target now agree; a `params`/`param` mismatch
    # here would raise NameError.
    for param in model.parameters():
        param.requires_grad = True

In [17]:
def main(dataset_path='../inaturalist_12K/', data_augmentation=False, batch_size=32,
         num_classes=10, fine_tuning_method=1, k=10, num_epochs=10, seed=None):
    """Fine-tune a pretrained GoogLeNet on the dataset under ``dataset_path``.

    fine_tuning_method:
        1 -> freeze the whole backbone and train only the new classifier head;
        2 -> freeze the first ``k`` top-level children and train the rest;
        any other value -> train every layer.

    seed: if given, seeds ``random`` and ``torch``. This makes the train/val
        split (drawn from the global ``random`` state) and the initialisation of
        the new head reproducible.
    """
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    train_loader, val_loader, test_loader, class_names = data_generation(
        dataset_path,
        num_classes=num_classes,
        data_augmentation=data_augmentation,
        batch_size=batch_size,
    )
    # These are batch counts, not sample counts.
    print("Train: ", len(train_loader))
    print("Val: ", len(val_loader))
    print("Test: ", len(test_loader))

    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print("Device: ", device)

    # NOTE: `pretrained=True` is deprecated in newer torchvision in favour of `weights=`.
    model = models.googlenet(pretrained=True)

    if fine_tuning_method == 1:
        feature_extraction(model, device)
    elif fine_tuning_method == 2:
        freeze_till_k(model, device, k)
    else:
        no_freezing(model, device, k)

    # Replace the classifier head. It is created after freezing, so it stays trainable.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.to(device)
    trainCNN(device, train_loader, val_loader, test_loader, model, num_epochs=num_epochs, optimizer="Adam")


main()

NameError: name 'wandb' is not defined