In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [4]:
import time

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image

In [3]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

# I - Vanilla classification with pretrained VGG

In [12]:
! wget "https://github.com/MegaloPat/DNN/blob/main/DNN/aligned.zip"

--2023-01-27 21:43:48--  https://github.com/MegaloPat/DNN/blob/main/DNN/aligned.zip
Resolving github.com (github.com)... 20.205.243.166
Connecting to github.com (github.com)|20.205.243.166|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [text/html]
Saving to: ‘aligned.zip’

aligned.zip             [ <=>                ] 132.50K  --.-KB/s    in 0.01s   

2023-01-27 21:43:48 (12.1 MB/s) - ‘aligned.zip’ saved [135684]



In [9]:
! wget https://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/vgg_face_dag.pth

--2023-01-27 21:42:37--  https://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/vgg_face_dag.pth
Resolving www.robots.ox.ac.uk (www.robots.ox.ac.uk)... 129.67.94.2
Connecting to www.robots.ox.ac.uk (www.robots.ox.ac.uk)|129.67.94.2|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 580015466 (553M)
Saving to: ‘vgg_face_dag.pth’


2023-01-27 21:43:19 (13.5 MB/s) - ‘vgg_face_dag.pth’ saved [580015466/580015466]



## Prepare dataset

### Unzip data

In [18]:
import zipfile

with zipfile.ZipFile("aligned.zip", 'r') as zip_ref:
    zip_ref.extractall("./aligned")

In [19]:
!mkdir aligned/train
!mkdir aligned/test
!mv aligned/aligned/train_* aligned/train
!mv aligned/aligned/test_* aligned/test

### Prepare csv labels

In [20]:
import csv
# Creation des labels sous formats csv

with open("list_patition_label.txt","r") as file :
    train_csv = open("train_list_label.csv","w",newline="")
    test_csv = open("test_list_label.csv","w",newline="")

    train_writer = csv.writer(train_csv)
    train_writer.writerow(["Filename", "Label"])
    
    test_writer = csv.writer(test_csv)
    test_writer.writerow(["Filename", "Label"])
    
    
    for line in file:
        filename, label = line.strip().split(" ")
        idx = filename.index(".jpg")
        filename = filename[:idx] + "_aligned" + filename[idx:]
        label = str(int(label) - 1)
        
        if "train" in filename :
            train_writer.writerow([filename, label])
        else :
            test_writer.writerow([filename, label])


### Preprocessing transform

In [21]:
trans = transforms.Compose([
    transforms.Lambda(lambda x: x.float()),
    transforms.Resize((224,224)),
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip()
])

### Create dataloaders (split is train/test/val = 80/10/10)

In [22]:
import pandas as pd
import os
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [23]:
train_data = CustomImageDataset("train_list_label.csv","./aligned/train", transform=trans)
train_loader = DataLoader(train_data, batch_size=16, shuffle=True, num_workers=4)
print(f"\nNb batches in train: {len(train_loader)}")


Nb batches in train: 758




In [24]:
from sklearn.model_selection import train_test_split
test_data = CustomImageDataset("test_list_label.csv","./aligned/test", transform=trans)

test_indices, val_indices = train_test_split(list(range(len(test_data.img_labels.Label))), test_size=0.5, stratify=test_data.img_labels.Label)

val_data = torch.utils.data.Subset(test_data, val_indices)
test_data = torch.utils.data.Subset(test_data, test_indices)


In [25]:
test_loader = DataLoader(test_data, batch_size=16, shuffle=True, num_workers=4)
print(f"\nNb batches in test: {len(test_loader)}")


Nb batches in test: 92




In [26]:
val_loader = DataLoader(val_data, batch_size=16, shuffle=True, num_workers=4)
print(f"\nNb batches in val: {len(val_loader)}")


Nb batches in val: 92


## VGG class

In [33]:

import torch
import torch.nn as nn


class Vgg(nn.Module):

    def __init__(self):
        super(Vgg, self).__init__()
        self.meta = {'mean': [129.186279296875, 104.76238250732422, 93.59396362304688],
                     'std': [1, 1, 1],
                     'imageSize': [224, 224, 3]}
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu1_1 = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu1_2 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu2_2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu3_1 = nn.ReLU(inplace=True)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu3_2 = nn.ReLU(inplace=True)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu3_3 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu4_1 = nn.ReLU(inplace=True)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu4_2 = nn.ReLU(inplace=True)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu4_3 = nn.ReLU(inplace=True)
        self.pool4 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu5_1 = nn.ReLU(inplace=True)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu5_2 = nn.ReLU(inplace=True)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu5_3 = nn.ReLU(inplace=True)
        self.pool5 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
        self.fc6 = nn.Linear(in_features=25088, out_features=4096, bias=True)
        self.relu6 = nn.ReLU(inplace=True)
        self.dropout6 = nn.Dropout(p=0.5)
        self.fc7 = nn.Linear(in_features=4096, out_features=4096, bias=True)
        self.relu7 = nn.ReLU(inplace=True)
        self.dropout7 = nn.Dropout(p=0.5)
        self.fc8 = nn.Linear(in_features=4096, out_features=2622, bias=True)

    def forward(self, x0):
        x1 = self.conv1_1(x0)
        x2 = self.relu1_1(x1)
        x3 = self.conv1_2(x2)
        x4 = self.relu1_2(x3)
        x5 = self.pool1(x4)
        x6 = self.conv2_1(x5)
        x7 = self.relu2_1(x6)
        x8 = self.conv2_2(x7)
        x9 = self.relu2_2(x8)
        x10 = self.pool2(x9)
        x11 = self.conv3_1(x10)
        x12 = self.relu3_1(x11)
        x13 = self.conv3_2(x12)
        x14 = self.relu3_2(x13)
        x15 = self.conv3_3(x14)
        x16 = self.relu3_3(x15)
        x17 = self.pool3(x16)
        x18 = self.conv4_1(x17)
        x19 = self.relu4_1(x18)
        x20 = self.conv4_2(x19)
        x21 = self.relu4_2(x20)
        x22 = self.conv4_3(x21)
        x23 = self.relu4_3(x22)
        x24 = self.pool4(x23)
        x25 = self.conv5_1(x24)
        x26 = self.relu5_1(x25)
        x27 = self.conv5_2(x26)
        x28 = self.relu5_2(x27)
        x29 = self.conv5_3(x28)
        x30 = self.relu5_3(x29)
        x31_preflatten = self.pool5(x30)
        x31 = x31_preflatten.view(x31_preflatten.size(0), -1)
        x32 = self.fc6(x31)
        x33 = self.relu6(x32)
        x34 = self.dropout6(x33)
        x35 = self.fc7(x34)
        x36 = self.relu7(x35)
        x37 = self.dropout7(x36)
        x38 = self.fc8(x37)
        return x38

def vgg_face(weights_path=None, **kwargs):
    """
    load imported model instance

    Args:
        weights_path (str): If set, loads model weights from the given path
    """
    model = Vgg()
    if weights_path:
        state_dict = torch.load(weights_path)
        model.load_state_dict(state_dict)
    return model

### Load pretrained weights on vgg_face

In [34]:
vgg = vgg_face("vgg_face_dag.pth")
vgg.fc8 = nn.Linear(in_features=4096, out_features=7, bias=True)
vgg = vgg.to(device)

## Train

### Initial evaluation on val dataset

In [55]:
cross_entropy = nn.CrossEntropyLoss()
def eval_model(net, loader):
  net.eval()
  acc, loss = 0., 0.
  c = 0
  for x, y in loader:
    with torch.no_grad():
      # No need to compute gradient here thus we avoid storing intermediary activations
      
      logits = net(x.to(device)).cpu()

    loss += cross_entropy(logits, y).item()
    preds = logits.argmax(dim=1)
    acc += (preds.numpy() == y.numpy()).sum()
    c += len(x)
    break

  acc /= c
  loss /= len(loader)
  net.train()
  return acc, loss

In [56]:


initial_acc, initial_loss = eval_model(vgg, val_loader)
print(f"Initial accuracy/loss on val: {round(100 * initial_acc, 2)}/{round(initial_loss, 4)}")



ValueError: ignored

### Training

In [37]:
from torch.optim.lr_scheduler import PolynomialLR


optimizer = torch.optim.Adam(vgg.parameters(), lr=0.00005)
scheduler = PolynomialLR(optimizer, total_iters=75, power=1.0)

nb_epochs = 75

train_accs, train_losses = [], []
val_accs, val_losses = [], []

In [38]:
from tqdm import tqdm
best_acc = 0
for epoch in range(nb_epochs):
  with tqdm(train_loader, unit="batch") as tepoch:
    start = time.time()
    running_acc, running_loss = 0., 0.
    c = 0
    for x, y in tepoch:
      x, y = x.to(device), y.to(device)

      optimizer.zero_grad()  # Clear previous gradients
      logits = vgg(x)
      loss = cross_entropy(logits, y)
      loss.backward()  # Compute gradients
      optimizer.step()  # Update weights with gradients
      scheduler.step()

      running_acc += (logits.argmax(dim=1).cpu().numpy() == y.cpu().numpy()).sum()
      running_loss += loss.item()
      c += len(x)
      tepoch.set_postfix(loss=loss.item())

    train_acc, train_loss = running_acc / c, running_loss / len(train_loader)
    train_accs.append(train_acc)
    train_losses.append(train_loss)
    
    val_acc, val_loss = eval_model(vgg, val_loader, cross_entropy)
    if val_acc > best_acc:
      best_acc = val_acc
      torch.save(vgg.state_dict(),"vgg_best_param.pth")
    val_accs.append(val_acc)
    val_losses.append(val_loss)

    print(
        f"Epoch {epoch + 1}/{nb_epochs}, "
        f"train acc/loss: {round(100 * train_acc, 2)}/{round(train_loss, 4)}, "
        f"val acc/loss: {round(100 * val_acc, 2)}/{round(val_loss, 4)}, "
        f"time {int(time.time() - start)}s"
    )

100%|██████████| 758/758 [03:40<00:00,  3.44batch/s, loss=1.12]


Epoch 1/75, train acc/loss: 48.54/1.4031, val acc/loss: 68.75/0.0096, time 224s


100%|██████████| 758/758 [03:39<00:00,  3.45batch/s, loss=1.23]


Epoch 2/75, train acc/loss: 49.32/1.3716, val acc/loss: 56.25/0.0135, time 220s


100%|██████████| 758/758 [03:39<00:00,  3.45batch/s, loss=1.46]


Epoch 3/75, train acc/loss: 49.06/1.3741, val acc/loss: 31.25/0.0164, time 221s


100%|██████████| 758/758 [03:39<00:00,  3.45batch/s, loss=1.19]


Epoch 4/75, train acc/loss: 48.95/1.3735, val acc/loss: 56.25/0.0143, time 220s


100%|██████████| 758/758 [03:39<00:00,  3.45batch/s, loss=1.22]


Epoch 5/75, train acc/loss: 49.37/1.3782, val acc/loss: 43.75/0.0203, time 220s


100%|██████████| 758/758 [03:40<00:00,  3.44batch/s, loss=1.31]


Epoch 6/75, train acc/loss: 49.01/1.3786, val acc/loss: 68.75/0.0118, time 221s


100%|██████████| 758/758 [03:40<00:00,  3.43batch/s, loss=1.95]


Epoch 7/75, train acc/loss: 49.29/1.3739, val acc/loss: 68.75/0.0101, time 221s


100%|██████████| 758/758 [03:38<00:00,  3.46batch/s, loss=0.752]


Epoch 8/75, train acc/loss: 49.36/1.375, val acc/loss: 75.0/0.016, time 221s


  2%|▏         | 18/758 [00:05<04:02,  3.05batch/s, loss=1.52]


KeyboardInterrupt: ignored

In [None]:
plt.subplot(1, 2, 1)
plt.plot(list(range(nb_epochs)), train_accs, label="Train")
plt.plot(list(range(nb_epochs)), val_accs, label="Val")
plt.title("Accuracy")
plt.subplot(1, 2, 2)
plt.plot(list(range(nb_epochs)), train_losses, label="Train")
plt.plot(list(range(nb_epochs)), val_losses, label="Val")
plt.title("Loss")

### Testing

In [None]:
state_dict = torch.load("vgg_best_param.pth")
vgg.load_state_dict(state_dict)

In [None]:
test_acc, test_loss = eval_model(vgg, test_loader, cross_entropy)
test_acc, test_loss

# II - Vgg with Pal


In [39]:
import zipfile
with zipfile.ZipFile("landmark.zip", 'r') as zip_ref:
    zip_ref.extractall("./landmark")

In [40]:
!mkdir landmark/train
!mkdir landmark/test
!mv landmark/landmark/train_* landmark/train
!mv landmark/landmark/test_* landmark/test

In [75]:
class JoinImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, landmark_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.landmark_dir = landmark_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        land_path = os.path.join(self.landmark_dir, self.img_labels.iloc[idx, 0])
        landmark = read_image(land_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, landmark, label

In [76]:
train_data = JoinImageDataset("train_list_label.csv","./aligned/train", "./landmark/train", transform=trans)
train_loader = DataLoader(train_data, batch_size=16, shuffle=True, num_workers=4)
print(f"\nNb batches in train: {len(train_loader)}")


Nb batches in train: 758


In [77]:
from sklearn.model_selection import train_test_split
test_data = JoinImageDataset("test_list_label.csv","./aligned/test","./landmark/test", transform=trans)

test_indices, val_indices = train_test_split(list(range(len(test_data.img_labels.Label))), test_size=0.5, stratify=test_data.img_labels.Label)

val_data = torch.utils.data.Subset(test_data, val_indices)
test_data = torch.utils.data.Subset(test_data, test_indices)

In [78]:
test_loader = DataLoader(test_data, batch_size=16, shuffle=True, num_workers=4)
print(f"\nNb batches in test: {len(test_loader)}")


Nb batches in test: 92




In [79]:
val_loader = DataLoader(val_data, batch_size=16, shuffle=True, num_workers=4)
print(f"\nNb batches in val: {len(val_loader)}")


Nb batches in val: 92


In [80]:
import torch
import torch.nn as nn


class VggPal(nn.Module):

    def __init__(self):
        super(VggPal, self).__init__()
        self.meta = {'mean': [129.186279296875, 104.76238250732422, 93.59396362304688],
                     'std': [1, 1, 1],
                     'imageSize': [224, 224, 3]}
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu1_1 = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu1_2 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu2_2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu3_1 = nn.ReLU(inplace=True)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu3_2 = nn.ReLU(inplace=True)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu3_3 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu4_1 = nn.ReLU(inplace=True)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu4_2 = nn.ReLU(inplace=True)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu4_3 = nn.ReLU(inplace=True)
        self.pool4 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu5_1 = nn.ReLU(inplace=True)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu5_2 = nn.ReLU(inplace=True)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
        self.relu5_3 = nn.ReLU(inplace=True)
        self.pool5 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
        self.fc6 = nn.Linear(in_features=25088, out_features=4096, bias=True)
        self.relu6 = nn.ReLU(inplace=True)
        self.dropout6 = nn.Dropout(p=0.5)
        self.fc7 = nn.Linear(in_features=4096, out_features=4096, bias=True)
        self.relu7 = nn.ReLU(inplace=True)
        self.dropout7 = nn.Dropout(p=0.5)
        self.fc8 = nn.Linear(in_features=4096, out_features=2622, bias=True)
        
    def forward(self, x0):
        x1 = self.conv1_1(x0)
        x2 = self.relu1_1(x1)
        x3 = self.conv1_2(x2)
        x4 = self.relu1_2(x3)
        x5 = self.pool1(x4)
        x6 = self.conv2_1(x5)
        x7 = self.relu2_1(x6)
        x8 = self.conv2_2(x7)
        x9 = self.relu2_2(x8)
        x10 = self.pool2(x9)
        x11 = self.conv3_1(x10)
        x12 = self.relu3_1(x11)
        x13 = self.conv3_2(x12)
        x14 = self.relu3_2(x13)
        x15 = self.conv3_3(x14)
        x16 = self.relu3_3(x15)
        x17 = self.pool3(x16)
        x18 = self.conv4_1(x17)
        x19 = self.relu4_1(x18)
        x20 = self.conv4_2(x19)
        x21 = self.relu4_2(x20)
        x22 = self.conv4_3(x21)
        x23 = self.relu4_3(x22)
        x24 = self.pool4(x23)
        x25 = self.conv5_1(x24)
        x26 = self.relu5_1(x25)
        x27 = self.conv5_2(x26)
        x28 = self.relu5_2(x27)
        x29 = self.conv5_3(x28)
        x30 = self.relu5_3(x29)
        x31_preflatten = self.pool5(x30)
        x31 = x31_preflatten.view(x31_preflatten.size(0), -1)
        x32 = self.fc6(x31)
        x33 = self.relu6(x32)
        x34 = self.dropout6(x33)
        x35 = self.fc7(x34)
        x36 = self.relu7(x35)
        x37 = self.dropout7(x36)
        x38 = self.fc8(x37)
        return x24,x25,x38

def vgg_palface(weights_path=None, **kwargs):
    """
    load imported model instance

    Args:
        weights_path (str): If set, loads model weights from the given path
    """
    model = VggPal()
    if weights_path:
        state_dict = torch.load(weights_path)
        model.load_state_dict(state_dict)
    return model

In [81]:
vggpal = vgg_palface("vgg_face_dag.pth")
vggpal.fc8 = nn.Linear(in_features=4096, out_features=7, bias=True)
vggpal = vggpal.to(device)

In [82]:
def Grad(inputs, outputs) :
    outputs_sum = outputs.sum()
    inputs.retain_grad()
    outputs.retain_grad()
    outputs_sum.backward(retain_graph=True)
    return torch.abs(inputs.grad)

def GradxInput(inputs, outputs) :
    Grad_val = Grad(inputs, outputs)
    return Grad_val * inputs

def pal_loss(inputs, outputs, prior, attribution_method, channel_strategy=None) :
    attribution_map = attribution_method(inputs, outputs)
    
    if channel_strategy == "half_mean" :
        nb_class = attribution_map.shape[1]
        attribution_map[:, nb_class/2:, :, :]
    
    if channel_strategy == "half_mean" or channel_strategy == "mean" :
        attribution_map = attribution_map.mean(1).unsqueeze(1)
    
    attribution_map_resize = transforms.Resize(attribution_map.shape[-2:])
    
    prior = attribution_map_resize(prior)
    
    std = attribution_map.view(attribution_map.size(0), -1).std(1)
    mean = attribution_map.view(attribution_map.size(0), -1).mean(1)
    
    
    res = (attribution_map - mean.view(-1, 1, 1, 1)) / std.view(-1, 1, 1, 1)
    res = res * prior.unsqueeze(1)
    
    res = res.view(res.size(0), -1).sum(1)
    res = -res
    return res.mean()

In [83]:
def tot_loss(li, lo, prior, logits, y):
  p_loss = pal_loss(li,lo,prior,GradxInput,"half_mean")
  ce_loss = cross_entropy(logits, y)
  return pal_loss + ce_loss

In [84]:
def eval_modelpal(net, loader):
  net.eval()
  acc, loss = 0., 0.
  c = 0
  for x, prior, y in loader:
    li, lo, logits = net(x.to(device))
    li, lo, logits = li.cpu(), lo.cpu(), logits.cpu()
    loss += tot_loss(li, lo, prior, logits, y).item()
    preds = logits.argmax(dim=1)
    acc += (preds.numpy() == y.numpy()).sum()
    c += len(x)
    break

  acc /= c
  loss /= len(loader)
  net.train()
  return acc, loss

In [85]:
initial_acc, initial_loss = eval_modelpal(vggpal, val_loader)
print(f"Initial accuracy/loss on val: {round(100 * initial_acc, 2)}/{round(initial_loss, 4)}")

TypeError: ignored

In [None]:
optimizer = torch.optim.Adam(vggpal.parameters(), lr=0.00005)
scheduler = PolynomialLR(optimizer, total_iters=75, power=1.0)

nb_epochs = 75

train_accs, train_losses = [], []
val_accs, val_losses = [], []

In [None]:
from tqdm import tqdm
best_acc = 0
for epoch in range(nb_epochs):
  with tqdm(train_loader, unit="batch") as tepoch:
    start = time.time()
    running_acc, running_loss = 0., 0.
    c = 0
    for x, prior, y in tepoch:
      x, prior, y = x.to(device), prior.to(device),y.to(device)

      optimizer.zero_grad()  # Clear previous gradients
      li, lo ,logits = vggpal(x)
      loss = tot_loss(li, lo, prior, logits, y)
      loss.backward()  # Compute gradients
      optimizer.step()  # Update weights with gradients
      scheduler.step()

      running_acc += (logits.argmax(dim=1).cpu().numpy() == y.cpu().numpy()).sum()
      running_loss += loss.item()
      c += len(x)
      tepoch.set_postfix(loss=loss.item())

    train_acc, train_loss = running_acc / c, running_loss / len(train_loader)
    train_accs.append(train_acc)
    train_losses.append(train_loss)
    
    val_acc, val_loss = eval_model(vggpal, val_loader, cross_entropy)
    if val_acc > best_acc:
      best_acc = val_acc
      torch.save(vggpal.state_dict(),"vggpal_best_param.pth")
    val_accs.append(val_acc)
    val_losses.append(val_loss)

    print(
        f"Epoch {epoch + 1}/{nb_epochs}, "
        f"train acc/loss: {round(100 * train_acc, 2)}/{round(train_loss, 4)}, "
        f"val acc/loss: {round(100 * val_acc, 2)}/{round(val_loss, 4)}, "
        f"time {int(time.time() - start)}s"
    )

In [None]:
plt.subplot(1, 2, 1)
plt.plot(list(range(nb_epochs)), train_accs, label="Train")
plt.plot(list(range(nb_epochs)), val_accs, label="Val")
plt.title("Accuracy")
plt.subplot(1, 2, 2)
plt.plot(list(range(nb_epochs)), train_losses, label="Train")
plt.plot(list(range(nb_epochs)), val_losses, label="Val")
plt.title("Loss")

In [None]:
state_dict = torch.load("vggpal_best_param.pth")
vgg.load_state_dict(state_dict)

In [None]:
test_acc, test_loss = eval_model(vgg, test_loader, cross_entropy)
test_acc, test_loss