# Readme for the Code
The provided code is a deep learning pipeline for training a model using the SeResNet architecture on a dataset of face images. The pipeline includes data loading, model creation, training, validation, and testing. Here's a breakdown of the major components and processes in the code:

#### Libraries
The code begins by installing and importing necessary libraries such as PyTorch, torchvision, wandb, and other relevant packages.

#### Data Loading and Preprocessing
The code then sets up the data directories, defines transformations for the training and validation datasets, and creates data loaders for the training, validation, and test sets.

#### Visualization
A visualization section is included to display a few images from the dataset as a sanity check for data augmentation.

#### Network Architecture
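The `SeResNetNetwork` class wraps a `seresnet34` backbone created with the timm library (randomly initialised, i.e. `pretrained=False`) and replaces its final fully connected layer with a classification head that has 2 outputs by default. The model is then moved to the GPU if one is available, and its total number of parameters is printed.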

#### Training
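The `train` function runs one epoch of training and the `validate` function evaluates the model on the validation set. Training uses mixed precision (`torch.cuda.amp` autocast together with a `GradScaler`), the Adam optimizer with a label-smoothed cross-entropy loss, and a `ReduceLROnPlateau` scheduler that halves the learning rate when the validation loss stops improving.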

#### Wandb Integration
The code integrates with Weights & Biases (wandb) for experiment tracking and visualization of training and validation metrics.

#### Testing
The testing process is defined, including the test function for evaluating the model on the test set and calculating the test accuracy.

#### Results
The code concludes by printing the test accuracy achieved by the trained model.

Overall, the code provides a comprehensive deep learning pipeline for training and evaluating a SeResNet model on a face image dataset.

The code is well-structured and includes detailed comments to explain each section and its purpose. It also integrates with wandb for experiment tracking and visualization of training and validation metrics.

The provided code demonstrates best practices for deep learning model training, including data loading, model creation, training, validation, testing, and experiment tracking using wandb.

# Libraries

In [None]:
!pip install wandb --quiet

In [None]:
import torch
from torch import nn
import torchvision #This library is used for image-based operations (Augmentations)
import os
import gc
from tqdm import tqdm
from PIL import Image
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
import wandb
import glob
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
# Pick the compute device once; every later cell moves tensors/models to DEVICE.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print("Device: ", DEVICE)

In [None]:
# Hyper-parameters shared by the data loaders, the optimizer and the training loop.
config = dict(
    batch_size=64,  # Increase this if your GPU can handle it
    lr=1e-3,
    epochs=100,
)

In [None]:
# Dataset root and its three split sub-folders.
DATA_DIR = '/kaggle/input/data-files/kaggle/working/dataset/Faceswap_images'
TRAIN_DIR, VAL_DIR, TEST_DIR = (
    os.path.join(DATA_DIR, split) for split in ("train", "val", "test")
)

# Training pipeline: resize, then random geometric augmentations, then tensor conversion.
train_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.RandomPerspective(0.3, 0.3),
        torchvision.transforms.RandomRotation(degrees=20),
        torchvision.transforms.RandomHorizontalFlip(p=0.3),
        torchvision.transforms.ToTensor(),
    ]
)

# Evaluation pipeline (validation and test): resize and tensor conversion only.
valid_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
    ]
)

# One ImageFolder per split; class labels come from the sub-directory names.
train_dataset = torchvision.datasets.ImageFolder(TRAIN_DIR, transform=train_transforms)
valid_dataset = torchvision.datasets.ImageFolder(VAL_DIR, transform=valid_transforms)
test_dataset = torchvision.datasets.ImageFolder(TEST_DIR, transform=valid_transforms)


def _build_loader(dataset, *, shuffle, **extra):
    """Create a DataLoader with the batch size and worker count shared by all splits."""
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=config['batch_size'],
        shuffle=shuffle,
        num_workers=2,
        **extra,
    )


# Create data loaders: only the training split is shuffled and uses pinned memory.
train_loader = _build_loader(train_dataset, shuffle=True, pin_memory=True)
valid_loader = _build_loader(valid_dataset, shuffle=False)
test_loader = _build_loader(test_dataset, shuffle=False)

In [None]:
# Show the class names found for each split so they can be compared by eye.
for split_dataset in (train_dataset, valid_dataset, test_dataset):
    print(split_dataset.classes)

In [None]:
# Quick summary of the training data and of the loaders built from it.
print("Number of classes    : ", len(train_dataset.classes))
print("No. of train images  : ", len(train_dataset))
print("Shape of image       : ", train_dataset[0][0].shape)  # first sample, after train_transforms
print("Batch size           : ", config['batch_size'])
print("Train batches        : ", len(train_loader))
print("Val batches          : ", len(valid_loader))

# Visualization

In [None]:
# Visualize a few augmented training images as a sanity check of the transforms.

r, c = [5, 5]
# squeeze=False keeps `ax` a 2-D array even when r or c is 1, so `ax.flat` always works.
fig, ax = plt.subplots(r, c, figsize=(15, 15), squeeze=False)

# Reuse the already-built training dataset instead of scanning the folder a second time.
dtl = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
)

# Only the first shuffled batch is needed.
x, y = next(iter(dtl))
class_names = dtl.dataset.classes

for k, axis in enumerate(ax.flat):
    axis.axis('off')
    if k >= len(x):
        # The batch holds fewer than r * c images: leave the remaining cells empty
        # instead of indexing past the end of the batch.
        continue
    # ToTensor yields channel-first images; imshow expects channel-last.
    axis.imshow(x[k].numpy().transpose(1, 2, 0))
    # y[k] is the class index assigned by ImageFolder; show the class name as the title.
    axis.set_title(class_names[y[k].item()])

del dtl

# Network architecture

In [None]:
# !pip install timm

In [None]:
import timm

In [None]:
class SeResNetNetwork(nn.Module):
    """SE-ResNet-34 classifier built on a timm backbone.

    The backbone is created without pretrained weights and its final fully
    connected layer is replaced by a fresh linear layer with `num_classes`
    outputs.

    Args:
        num_classes: number of output classes of the classification head.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.backbone = timm.create_model("seresnet34", pretrained=False)
        # Swap the backbone's classifier for one sized to this task.
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(num_features, num_classes)

    def forward(self, x, return_feats=False):
        """Run the network on a batch of images.

        Args:
            x: input image batch.
            return_feats: when True, return the pooled features that feed the
                classifier instead of the class logits. (Previously both
                branches returned the logits, so the flag had no effect.)

        Returns:
            Class logits, or the pre-classifier feature vectors when
            `return_feats` is True.
        """
        if return_feats:
            feature_map = self.backbone.forward_features(x)
            # pre_logits=True applies the global pooling but skips the final fc layer.
            return self.backbone.forward_head(feature_map, pre_logits=True)
        return self.backbone(x)

# Build the network and place it on the device selected at the top of the notebook.
model = SeResNetNetwork()
model = model.to(DEVICE)

# Report the model size as the total element count over all parameter tensors.
total_params = sum(map(torch.numel, model.parameters()))
print(f"Total Parameters: {total_params}")

In [None]:
# Loss: cross-entropy with label smoothing.
criterion = nn.CrossEntropyLoss(label_smoothing=0.15)

# Optimizer: Adam with a lowered first-moment decay (beta1 = 0.5) and L2 weight decay.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=config['lr'],
    betas=(0.5, 0.999),
    weight_decay=1e-4,
)

# Scheduler: multiplies the learning rate by 0.5 when the metric passed to
# scheduler.step() has stopped improving (patience of 2 epochs).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    factor=0.5,
    patience=2,
)

# Gradient scaler used together with autocast for mixed-precision training.
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Remember the learning rate the optimizer starts with.
initial_lr = optimizer.param_groups[0]['lr']
print(f"Initial learning rate: {initial_lr}")


def check_lr_change():
    """Print the optimizer's current learning rate if it differs from `initial_lr`.

    The comparison is always against the starting value, so once the rate has
    changed the message is printed on every call.
    """
    lr_now = optimizer.param_groups[0]['lr']
    if lr_now == initial_lr:
        return
    print(f"Learning rate changed to: {lr_now}")

## Train

In [None]:
def train(model, dataloader, optimizer, criterion):
    """Train `model` for one epoch with mixed precision.

    Relies on the module-level `DEVICE` and on the module-level `scaler`
    (a GradScaler) for loss scaling.

    Args:
        model: network to train; switched to train mode here.
        dataloader: iterable of (images, labels) batches.
        optimizer: optimizer stepped once per batch through `scaler`.
        criterion: loss function called as criterion(outputs, labels).

    Returns:
        (acc, total_loss): accuracy in percent over the samples actually seen,
        and the mean of the per-batch losses. Both are 0.0 for an empty loader.
    """
    model.train()

    # Progress Bar
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, leave=False, position=0, desc='Train', ncols=5)

    num_correct = 0
    num_seen = 0  # true sample count; the last batch may be smaller than batch_size
    total_loss = 0.0

    for i, (images, labels) in enumerate(dataloader):

        optimizer.zero_grad()  # Zero gradients

        images, labels = images.to(DEVICE), labels.to(DEVICE)

        with torch.cuda.amp.autocast():  # mixed-precision forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Update no. of correct predictions & loss as we iterate
        num_correct += int((torch.argmax(outputs, dim=1) == labels).sum())
        num_seen += labels.size(0)
        total_loss += float(loss.item())

        # Running metrics shown next to the progress bar.
        batch_bar.set_postfix(
            acc="{:.04f}%".format(100 * num_correct / num_seen),
            loss="{:.04f}".format(float(total_loss / (i + 1))),
            num_correct=num_correct,
            # Scientific notation keeps small learning rates readable after decay.
            lr="{:.2e}".format(float(optimizer.param_groups[0]['lr']))
        )

        scaler.scale(loss).backward()  # a replacement for loss.backward()
        scaler.step(optimizer)  # a replacement for optimizer.step()
        scaler.update()

        batch_bar.update()  # Update tqdm bar

    batch_bar.close()  # close the tqdm bar

    # Divide by the samples actually processed, not batch_size * num_batches,
    # which under-reports accuracy whenever the final batch is partial.
    acc = 100 * num_correct / max(num_seen, 1)
    total_loss = float(total_loss / max(len(dataloader), 1))

    return acc, total_loss

In [None]:
def validate(model, dataloader, criterion):
    """Evaluate `model` on a labelled dataloader without updating weights.

    Relies on the module-level `DEVICE`.

    Args:
        model: network to evaluate; switched to eval mode here.
        dataloader: iterable of (images, labels) batches.
        criterion: loss function called as criterion(outputs, labels).

    Returns:
        (acc, total_loss): accuracy in percent over the samples actually seen,
        and the mean of the per-batch losses. Both are 0.0 for an empty loader.
    """
    model.eval()
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, position=0, leave=False, desc='Val', ncols=5)

    num_correct = 0
    num_seen = 0  # true sample count; the last batch may be smaller than batch_size
    total_loss = 0.0

    for i, (images, labels) in enumerate(dataloader):

        # Move images and labels to the device
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        # Get model outputs without tracking gradients
        with torch.inference_mode():
            outputs = model(images)
            loss = criterion(outputs, labels)

        num_correct += int((torch.argmax(outputs, dim=1) == labels).sum())
        num_seen += labels.size(0)
        total_loss += float(loss.item())

        batch_bar.set_postfix(
            acc="{:.04f}%".format(100 * num_correct / num_seen),
            loss="{:.04f}".format(float(total_loss / (i + 1))),
            num_correct=num_correct)

        batch_bar.update()

    batch_bar.close()

    # Divide by the samples actually processed, not batch_size * num_batches,
    # which under-reports accuracy whenever the final batch is partial.
    acc = 100 * num_correct / max(num_seen, 1)
    total_loss = float(total_loss / max(len(dataloader), 1))
    return acc, total_loss

In [None]:
# Free memory before starting the run: first collect unreachable Python objects,
# then release PyTorch's cached CUDA memory. The order matters, since objects
# must be collected before the memory they hold can be released.
gc.collect() # These commands help you when you face CUDA OOM error
torch.cuda.empty_cache()

## Wandb

In [None]:
# Authenticate with Weights & Biases. The key below is a placeholder and must be
# replaced before this cell can succeed.
# NOTE(review): prefer supplying the key out-of-band (e.g. via the WANDB_API_KEY
# environment variable) rather than committing a real key to the notebook.
wandb.login(key="Insert wandb API key") 

In [None]:
# Create your wandb run
run = wandb.init(
    # The run name now matches the model actually trained in this notebook
    # (timm "seresnet34" with pretrained=False); it used to read "Resnet50 Pretrained".
    name = "SEResNet34 (from scratch)",
    reinit = True, ### Allows reinitializing runs when you re-run this cell
    project = "Final_project",
    config = config ### Wandb Config for your run
)

In [None]:
# Main training loop: one train pass and one validation pass per epoch, metrics
# logged to wandb, LR scheduling on the validation loss, and a checkpoint written
# whenever the validation accuracy matches or beats the best seen so far.
best_valacc = 0

for epoch in range(config['epochs']):

    # Learning rate in effect for this epoch (read before the scheduler may change it).
    curr_lr = float(optimizer.param_groups[0]['lr'])

    train_acc, train_loss = train(model, train_loader, optimizer, criterion)

    # The learning rate is printed in scientific notation: with four fixed decimals
    # it rounds to 0.0000 after a few halvings by the scheduler.
    print("\nEpoch {}/{}: \nTrain Acc {:.04f}%\t Train Loss {:.04f}\t Learning Rate {:.2e}".format(
        epoch + 1,
        config['epochs'],
        train_acc,
        train_loss,
        curr_lr))

    val_acc, val_loss = validate(model, valid_loader, criterion)

    print("Val Acc {:.04f}%\t Val Loss {:.04f}".format(val_acc, val_loss))

    wandb.log({"train_loss":train_loss, 'train_Acc': train_acc, 'validation_Acc':val_acc,
               'validation_loss': val_loss, "learning_Rate": curr_lr})

    scheduler.step(val_loss)  # Adjust the learning rate based on validation loss
    check_lr_change()

    # Save the model (plus optimizer/scheduler state) when val_acc is at least as
    # good as the best recorded val_acc; ties overwrite the earlier checkpoint.
    if val_acc >= best_valacc:
        print("Saving model")
        torch.save({'model_state_dict':model.state_dict(),
                    'optimizer_state_dict':optimizer.state_dict(),
                    'scheduler_state_dict':scheduler.state_dict(),
                    'val_acc': val_acc,
                    'epoch': epoch}, './checkpoint.pth')
        best_valacc = val_acc
        wandb.save('checkpoint.pth')

run.finish()

In [None]:
# Restore the checkpoint written by the training loop.
# map_location remaps the stored tensors onto the current device, so a checkpoint
# saved during a GPU session also loads in a CPU-only session.
checkpoint = torch.load('./checkpoint.pth', map_location=DEVICE)

# Load model state dict
model.load_state_dict(checkpoint['model_state_dict'])
val_acc = checkpoint['val_acc']  # validation accuracy recorded when the checkpoint was saved

## Testing

In [None]:
def test(model, dataloader):
    """Predict a class index for every sample served by `dataloader`.

    Batches may be bare image tensors (unlabelled data) or (images, labels)
    pairs such as the ones the ImageFolder-backed loaders in this notebook
    yield; labels, when present, are ignored.

    NOTE: a later cell defines another `test` that returns an accuracy
    instead; whichever cell ran last is the one in effect.

    Args:
        model: network to evaluate; switched to eval mode here.
        dataloader: iterable of image batches or (images, labels) batches.

    Returns:
        list[int]: predicted class indices, in dataloader order.
    """
    model.eval()
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, position=0, leave=False, desc='Test')
    test_results = []

    for batch in dataloader:
        # `for i, (images) in ...` does not unpack anything, so a labelled batch
        # used to arrive as a [images, labels] list and `.to(DEVICE)` failed on it.
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        images = images.to(DEVICE)

        with torch.inference_mode():
            outputs = model(images)

        test_results.extend(torch.argmax(outputs, dim=1).cpu().tolist())

        batch_bar.update()

    batch_bar.close()
    return test_results

In [None]:


def test(model, dataloader):
    """Return the accuracy (a fraction, via accuracy_score) of `model` on a labelled dataloader.

    Args:
        model: network to evaluate; switched to eval mode here.
        dataloader: iterable of (images, labels) batches.

    Returns:
        The value of accuracy_score(true labels, predicted labels).
    """
    model.eval()
    progress = tqdm(total=len(dataloader), dynamic_ncols=True, position=0, leave=False, desc='Test')
    predictions, targets = [], []

    for images, labels in dataloader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        with torch.inference_mode():
            logits = model(images)

        # Collect plain Python ints so sklearn can compare the two lists.
        batch_predictions = torch.argmax(logits, axis=1).detach().cpu().numpy()
        predictions += batch_predictions.tolist()
        targets += labels.cpu().numpy().tolist()

        progress.update()

    progress.close()

    return accuracy_score(targets, predictions)

In [None]:
# Evaluate on the held-out test split. With the labelled `test` defined in the
# preceding cell this is a single accuracy value (a fraction), not a list of
# predictions, despite the variable's name.
test_results = test(model, test_loader)

In [None]:
# Report the accuracy as a percentage with two decimals; the unformatted product
# could print binary-float artefacts such as 93.12000000000001.
print(f"The test accuracy is {test_results * 100:.2f} %")