In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import time
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import models

In [2]:
# check GPU availability
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [3]:
from col_mnist import ColMNIST
trainloader = torch.utils.data.DataLoader(
  ColMNIST('data/mnist', train=True, download=True,
                             transform=torchvision.transforms.Compose([#torchvision.transforms.Resize((224, 224)),
                               torchvision.transforms.ToTensor(),
                             ])),
  batch_size=32, shuffle=True)

testloader = torch.utils.data.DataLoader(
  ColMNIST('data/mnist', train=False, download=True,
                             transform=torchvision.transforms.Compose([#torchvision.transforms.Resize((224, 224)),
                               torchvision.transforms.ToTensor()
                             ])),
  batch_size=32, shuffle=True)

In [4]:
vgg16 = models.vgg16(pretrained=True)
vgg16.to(device)
print(vgg16)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [5]:
# change the number of classes 
vgg16.classifier[6].out_features = 30
# freeze convolution weights
for param in vgg16.features.parameters():
    param.requires_grad = True

In [6]:
from models import DisentangledLinear, BlockDropout

n_classes = 30

vgg16.classifier[6] = DisentangledLinear(vgg16.classifier[6].in_features, n_classes).to(device)
vgg16.classifier[5] = BlockDropout(vgg16.classifier[6], ncc=2, apply_to="in")

for param in vgg16.features.parameters():
    param.requires_grad = True

In [7]:
# optimizer
optimizer = optim.SGD(vgg16.classifier.parameters(), lr=0.001, momentum=0.9)
# loss function
criterion = nn.CrossEntropyLoss()

In [8]:
# validation function
def validate(model, test_dataloader):
    model.eval()
    val_running_loss = 0.0
    val_running_correct = 0
    for int, data in enumerate(test_dataloader):
        data, (target, dclr_idx, bclr_idx) = data[0], data[1]
        target += bclr_idx*10
        data = data.to(device)
        target = target.to(device)
        output = model(data)
        loss = criterion(output, target)
        
        val_running_loss += loss.item()
        _, preds = torch.max(output.data, 1)
        val_running_correct += (preds == target).sum().item()
    
    val_loss = val_running_loss/len(test_dataloader.dataset)
    val_accuracy = 100. * val_running_correct/len(test_dataloader.dataset)
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}')
    
    return val_loss, val_accuracy

In [9]:
# training function
def fit(model, train_dataloader):
    model.train()
    train_running_loss = 0.0
    train_running_correct = 0
    for i, data in enumerate(train_dataloader):
        data, (target, dclr_idx, bclr_idx) = data[0], data[1]
        target += bclr_idx*10
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        train_running_loss += loss.item()
        _, preds = torch.max(output.data, 1)
        train_running_correct += (preds == target).sum().item()
        loss.backward()
        optimizer.step()
    train_loss = train_running_loss/len(train_dataloader.dataset)
    train_accuracy = 100. * train_running_correct/len(train_dataloader.dataset)
    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}')
    
    return train_loss, train_accuracy

In [None]:
train_loss , train_accuracy = [], []
val_loss , val_accuracy = [], []
start = time.time()
for epoch in range(30):
    print(epoch)
    train_epoch_loss, train_epoch_accuracy = fit(vgg16, trainloader)
    val_epoch_loss, val_epoch_accuracy = validate(vgg16, testloader)
    train_loss.append(train_epoch_loss)
    train_accuracy.append(train_epoch_accuracy)
    val_loss.append(val_epoch_loss)
    val_accuracy.append(val_epoch_accuracy)
    torch.save(vgg16, 'vgg16disen_e'+str(epoch)+'.pt')
end = time.time()
print((end-start)/60, 'minutes')

0
Train Loss: 0.0309, Train Acc: 69.28
Val Loss: 0.0183, Val Acc: 83.73
1
Train Loss: 0.0167, Train Acc: 82.73
Val Loss: 0.0136, Val Acc: 88.25
2
Train Loss: 0.0138, Train Acc: 85.64
Val Loss: 0.0133, Val Acc: 89.00
3
Train Loss: 0.0122, Train Acc: 87.30
Val Loss: 0.0111, Val Acc: 90.43
4
Train Loss: 0.0112, Train Acc: 88.38
Val Loss: 0.0112, Val Acc: 91.04
5
Train Loss: 0.0104, Train Acc: 89.11
Val Loss: 0.0098, Val Acc: 91.72
6
Train Loss: 0.0099, Train Acc: 89.74
Val Loss: 0.0094, Val Acc: 92.56
7
Train Loss: 0.0093, Train Acc: 90.25
Val Loss: 0.0085, Val Acc: 92.60
8
Train Loss: 0.0089, Train Acc: 90.67
Val Loss: 0.0098, Val Acc: 92.20
9
Train Loss: 0.0087, Train Acc: 90.97
Val Loss: 0.0088, Val Acc: 92.61
10
Train Loss: 0.0084, Train Acc: 91.18
Val Loss: 0.0084, Val Acc: 93.09
11
Train Loss: 0.0081, Train Acc: 91.52
Val Loss: 0.0082, Val Acc: 93.00
12
Train Loss: 0.0079, Train Acc: 91.69
Val Loss: 0.0080, Val Acc: 93.24
13
Train Loss: 0.0077, Train Acc: 91.96
Val Loss: 0.0092, Val

In [None]:
plt.figure(figsize=(10, 7))
plt.plot(train_accuracy, color='green', label='train accuracy')
plt.plot(val_accuracy, color='blue', label='validataion accuracy')
plt.legend()
plt.savefig('accuracy.png')
plt.show()

In [None]:
plt.figure(figsize=(10, 7))
plt.plot(train_loss, color='orange', label='train loss')
plt.plot(val_loss, color='red', label='validataion loss')
plt.legend()
plt.savefig('loss.png')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from col_mnist import ColMNIST

plt.rcParams["axes.grid"] = False
device = "cuda" if torch.cuda.is_available() else "cpu"

def imshow(img):
    #img = img / 2 + 0.5     # unnormalize
    npimg = img.cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

dataiter = iter(testloader)
images, (target, dclr_idx, bclr_idx) = dataiter.next()
target += bclr_idx*10
imshow(torchvision.utils.make_grid(images))
img_shape = images[0].shape
print("Image shape: {}".format(img_shape))
print(target)

In [None]:
output = vgg16(images)
output.argmax(dim=1)

In [None]:
from explainn_code.grab_functions import db_from_dat_with_labels, write_pic_as_sets

tail = []
head = []

for i, data in enumerate(train_dataloader):
    data, (target, dclr_idx, bclr_idx) = data[0], data[1]
    target += bclr_idx*10
    data = data.to(device)
    target = target.to(device)
    output = model(data).cpu()

    head.append(vgg16.classifier[6].detach().cpu().numpy())
    tail.append(target.detach().cpu().numpy())
head = np.concatenate(head)
tail = np.concatenate(tail)

In [None]:
write_pic_as_sets(head, "vgg_head_blocked_v1.data")

In [None]:
write_pic_as_sets(tail "tail_v1.dat")