In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR
import torchvision
import torchvision.transforms as transforms
import tqdm
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
from sklearn.metrics import roc_curve, accuracy_score, f1_score, auc, precision_recall_curve
from statistics import mean
from captum.attr import Saliency, DeepLift, IntegratedGradients

### Configuration

In [11]:
batch_size = 64
classes = ("plane", "car", "bird", "cat",
           "deer", "dog", "frog", "horse", "ship", "truck")
max_lr = 1e-3
n_epochs = 1
model_name = "CNN-2D"

### Dataset

In [12]:
transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ])
train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_dataset_loader = torch.utils.data.DataLoader(train_dataset, 
                                                   batch_size=batch_size, 
                                                   shuffle=True)                                   

test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
test_dataset_loader = torch.utils.data.DataLoader(test_dataset, 
                                                  batch_size=batch_size, 
                                                  shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


### Model

In [13]:
class CNN_Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_1 = nn.Conv2d(3, 6, 5, stride=1, padding="same")
        self.conv_2 = nn.Conv2d(6, 16, 5, stride=1, padding="same")
        self.maxpool = nn.MaxPool2d(2, 2)
        self.linear_1 = nn.Linear(16 * 8 * 8, 120)
        self.linear_2 = nn.Linear(120, 84)
        self.linear_3 = nn.Linear(84, 10)
        
    def forward(self, x):
        x = self.maxpool(F.relu(self.conv_1(x)))
        x = self.maxpool(F.relu(self.conv_2(x)))
        
        x = torch.flatten(x, 1) # flatten all dimension except the first one which corresponds to batch
        
        x = F.relu(self.linear_1(x))
        x = F.relu(self.linear_2(x))
        
        x = self.linear_3(x)
        return x

In [14]:
model = CNN_Model()

### Loss and Optimizer

In [15]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=max_lr)
lr_scheduler = OneCycleLR(
                    optimizer=optimizer,
                    max_lr=max_lr,
                    epochs=n_epochs,
                    steps_per_epoch=len(train_dataset_loader),
                    pct_start=0.1,
                    anneal_strategy='cos',
                    div_factor=25.0,
                    final_div_factor=10000.0)
tbw = SummaryWriter() # tensorboard summary writer

### Training

In [16]:
def train():
    model.train_iter = 0
    model.val_iter = 0
    
    for epoch in range(n_epochs):
        model.train()
        for _, data in enumerate(pbar := tqdm.tqdm(train_dataset_loader, 0)):
            model.train_iter += 1 
            # get input and labels; data is a list of [(inputs, labels)]
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            outputs = model(inputs)
            # loss
            loss = criterion(outputs, labels)
            # backward propagate loss
            loss.backward()
            # update parameters
            optimizer.step()
            lr_scheduler.step()

            # training_loss
            train_loss = loss.item()
            # tensorboard + logs
            tbw.add_scalar(f"{model_name}/training-loss", float(train_loss), model.train_iter)
            pbar.set_description(f"{model_name}/training-loss={train_loss}, steps={model.train_iter}, epoch={epoch+1}")
            
        validate(model, epoch)
        
    return model, validate(model, epoch)
    
def validate(model, epoch=0):
    with torch.no_grad():
        model.eval()

        results = []

        for _, data in enumerate(pbar := tqdm.tqdm(test_dataset_loader, 0)):
            # get input and labels; data is a list of [(inputs, labels)]
            inputs, labels = data

            output = model(inputs)
            # loss
            loss = criterion(output, labels)
            curr_val_loss = loss.item()
            model.val_iter += 1

            # tensorboard + logs
            tbw.add_scalar(f"{model_name}/validation-loss", float(curr_val_loss), model.val_iter)
            pbar.set_description(f"{model_name}/validation-loss={curr_val_loss}, steps={model.val_iter}, epoch={epoch+1}")

            # to get probabilities of the output
            output = F.softmax(output, dim=-1)
            result_df = pd.DataFrame(output.cpu().numpy())
            result_df["y_true"] = labels.cpu().numpy()
            results.append(result_df)
    
    return pd.concat(results, ignore_index=True)
        


In [17]:
model, results = train()

CNN-2D/training-loss=2.3232076168060303, steps=782, epoch=1: 100%|████████████████████████| 782/782 [02:04<00:00,  6.26it/s]
CNN-2D/validation-loss=2.29150128364563, steps=157, epoch=1: 100%|████████████████████████| 157/157 [00:10<00:00, 15.36it/s]
CNN-2D/validation-loss=2.29150128364563, steps=314, epoch=1: 100%|████████████████████████| 157/157 [00:10<00:00, 15.32it/s]


In [19]:
results

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,y_true
0,0.099236,0.094794,0.093328,0.099401,0.101005,0.100911,0.108543,0.102514,0.105266,0.095001,3
1,0.099690,0.094495,0.092473,0.098701,0.100804,0.101270,0.108735,0.102835,0.106296,0.094700,8
2,0.099770,0.094113,0.093175,0.099116,0.100784,0.101022,0.108739,0.102015,0.106203,0.095062,8
3,0.100045,0.094511,0.093239,0.098933,0.101167,0.100464,0.108663,0.102330,0.105852,0.094796,0
4,0.098358,0.094266,0.093771,0.099647,0.101853,0.100885,0.108963,0.103319,0.104864,0.094074,6
...,...,...,...,...,...,...,...,...,...,...,...
9995,0.099488,0.094419,0.093041,0.099210,0.101168,0.101254,0.109353,0.103042,0.105253,0.093771,8
9996,0.098113,0.094693,0.092736,0.099841,0.101368,0.100895,0.109866,0.103509,0.104664,0.094315,3
9997,0.096948,0.094890,0.092133,0.100815,0.100713,0.102226,0.109501,0.104584,0.103807,0.094383,5
9998,0.097994,0.094658,0.093248,0.100330,0.100930,0.101213,0.109010,0.103006,0.104937,0.094674,1


In [20]:
auprcs = []
for i in range(10):
    precision, recall, _ = precision_recall_curve(y_true=results["y_true"].values, probas_pred=results[i].values, pos_label=i)
    auprc = auc(recall, precision)
    print(f"AUPRC for class {i} = {auprc}")
    auprcs.append(auprc)

macro_auprc = mean(auprcs)
print(f"Macro AUPRC = {macro_auprc}")

AUPRC for class 0 = 0.25787064348973504
AUPRC for class 1 = 0.1701077173030176
AUPRC for class 2 = 0.1736992011648282
AUPRC for class 3 = 0.12486837002079033
AUPRC for class 4 = 0.17783742889893367
AUPRC for class 5 = 0.10462516444183038
AUPRC for class 6 = 0.14732852023495913
AUPRC for class 7 = 0.1346943178682629
AUPRC for class 8 = 0.26444634546603746
AUPRC for class 9 = 0.14328781051171552
Macro AUPRC = 0.16987655194001103


In [27]:
dl = Saliency(model)
inputs = next(iter(test_dataset_loader))[0][0].unsqueeze(dim=0)


In [28]:
attribution = dl.attribute(inputs, target=2)



In [30]:
attribution.shape

torch.Size([1, 3, 32, 32])