In [None]:
import torchvision
import torch.nn as nn
import torch
import torch.nn.functional as F
from torchvision import models,datasets
# from torchvision import transforms as T
from torch.utils import data
from torchvision.models import vgg19
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from torch import optim
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import cv2, glob, numpy as np, pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
from glob import glob
!pip install torchsummary

In [None]:
!pip install -q kaggle

In [None]:
from google.colab import files

files.upload()

In [None]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!ls ~/.kaggle
!chmod 600 /root/.kaggle/kaggle.json

In [None]:
!kaggle datasets download -d tongpython/cat-and-dog

In [None]:
!ls

In [None]:
!unzip cat-and-dog.zip

In [None]:
# The archive is unzipped into the working directory (assumed to be Colab's
# /content); each split is nested one directory deeper under the same name.
_DATA_ROOT = '/content'
train_data_dir = _DATA_ROOT + '/training_set/training_set'
test_data_dir = _DATA_ROOT + '/test_set/test_set'

In [None]:
from torchvision import transforms as T

# Training-time pipeline, applied per image in cats_dogs.collate_fn to the
# float CHW tensor returned by cats_dogs.__getitem__. ToPILImage converts it
# back to an image so the PIL-based augmentations can run.
trn_tfms = T.Compose([
    T.ToPILImage(),
    T.Resize((224, 224)),
    # Mild photometric jitter (+/-5%) ...
    T.ColorJitter(brightness=(0.95, 1.05),
                  contrast=(0.95, 1.05),
                  saturation=(0.95, 1.05),
                  hue=0.05),
    # ... plus up to +/-5 degrees of rotation and a small translation.
    T.RandomAffine(5, translate=(0.01, 0.1)),
    T.ToTensor(),
    # (x - 0.5) / 0.25 maps [0, 1] to [-2, 2].
    T.Normalize(mean=[0.5, 0.5, 0.5],
                std=[0.25, 0.25, 0.25]),
])

In [None]:
# Evaluation pipeline: same resize and normalisation as training, no augmentation.
_val_steps = [
    T.ToPILImage(),
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5] * 3, std=[0.25] * 3),
]
val_tfms = T.Compose(_val_steps)

In [None]:
import logging
import os
import random

from torch.utils.data import DataLoader, Dataset

logger = logging.getLogger(__name__)


class cats_dogs(Dataset):
  """Cats-vs-dogs images read from ``<folder>/cats/*.jpg`` and ``<folder>/dogs/*.jpg``.

  Each item is ``(img, clss)``:

  - ``img`` is a float CHW RGB tensor in [0, 1], resized to 224x224.
  - ``clss`` is a 1-element float tensor (dog=1, cat=0).

  Both are placed on ``device``. The optional ``transform`` is NOT applied in
  ``__getitem__``; ``collate_fn`` applies it per image, so DataLoaders should
  pass ``collate_fn=dataset.collate_fn``.
  """

  def __init__(self, folder, transform=None):
    cats = glob(os.path.join(folder, 'cats', '*.jpg'))
    dogs = glob(os.path.join(folder, 'dogs', '*.jpg'))
    self.fpaths = cats + dogs
    # Deterministic shuffle with a private RNG (same order as seed(10)+shuffle)
    # so the global `random` state is left untouched.
    random.Random(10).shuffle(self.fpaths)
    # Label from the file-name prefix: 'dog...' -> True (1), anything else -> False (0).
    self.targets = [os.path.basename(fpath).startswith('dog') for fpath in self.fpaths]
    self.transform = transform
    logger.info(len(self))

  def __len__(self):
    return len(self.fpaths)

  def __getitem__(self, ix):
    f = self.fpaths[ix]
    target = self.targets[ix]
    bgr = cv2.imread(f)
    if bgr is None:
      # cv2.imread returns None instead of raising for unreadable files.
      raise OSError(f'could not read image: {f}')
    im = bgr[:, :, ::-1]  # OpenCV loads BGR; flip to RGB
    im = cv2.resize(im, (224, 224))
    img = torch.tensor(im / 255).permute(2, 0, 1).to(device).float()
    clss = torch.tensor([target]).float().to(device)
    return img, clss

  def choose(self):
    """Return a uniformly random item (handy for quick visual checks)."""
    return self[random.randrange(len(self))]

  def collate_fn(self, batch):
    """Batch a list of ``(img, clss)`` items.

    Returns ``(imgs, classes, _imgs)``:

    - ``imgs`` is ``(B, C, H, W)`` after the optional per-image transform.
    - ``classes`` is ``(B, 1)`` float labels.
    - ``_imgs`` is the tuple of untransformed images, kept for visualisation.

    ``imgs`` and ``classes`` are moved to ``device``.
    """
    _imgs, classes = list(zip(*batch))
    if self.transform is not None:
      imgs = [self.transform(img)[None] for img in _imgs]
    else:
      imgs = [img[None] for img in _imgs]
    # Each target is already a 1-element tensor; stacking gives (B, 1), the same
    # shape as a single-unit sigmoid output, as BCELoss requires.
    labels = torch.stack([clss.reshape(1) for clss in classes])
    return torch.cat(imgs).to(device), labels.to(device), _imgs

In [None]:
# Untransformed training set; item 200 is an arbitrary sample to eyeball.
# NOTE(review): this rebinds `data` (imported earlier as torch.utils.data).
data = cats_dogs(folder=train_data_dir)
sample = data[200]
im, label = sample

In [None]:
# Number of training images found on disk (cats + dogs .jpg files).
len(data)

In [None]:
# `im` is CHW; matplotlib wants HWC. The title shows the label (dog=1, cat=0).
fig, ax = plt.subplots()
ax.imshow(im.permute(1, 2, 0).cpu())
ax.set_title('dog' if label.item() else 'cat')
ax.axis('off')
plt.show()

In [None]:
def conv_layer(ni, no, kernel_size, stride=1):
    """Conv -> ReLU -> BatchNorm -> 2x2 max-pool block.

    The convolution is unpadded, so spatial size ``s`` becomes
    ``(s - kernel_size) // stride + 1`` before the pool halves it (floor).
    """
    return nn.Sequential(
        nn.Conv2d(ni, no, kernel_size, stride),
        nn.ReLU(),
        nn.BatchNorm2d(no),
        nn.MaxPool2d(2),
    )


class CatnDogClassifier(nn.Module):
  """Small CNN for cat (0) vs dog (1) classification of 3x224x224 inputs.

  Six unpadded conv blocks shrink 224x224 to 1x1, so flattening yields 512
  features. The head outputs one sigmoid probability per image, shape (B, 1).
  """

  def __init__(self):
    super().__init__()
    self.model = nn.Sequential(
        conv_layer(3, 64, 3),
        conv_layer(64, 512, 3),
        conv_layer(512, 512, 3),
        conv_layer(512, 512, 3),
        conv_layer(512, 512, 3),
        conv_layer(512, 512, 3),
        nn.Flatten(),
        nn.Linear(512, 1),
        nn.Sigmoid(),
    )

  def forward(self, x):
    return self.model(x)


def get_model():
    """Build the classifier together with its loss and optimizer.

    Returns ``(model, loss_fn, optimizer)``:

    - the model, on ``device``;
    - ``nn.BCELoss``, because the model already ends in a sigmoid;
    - Adam with lr=1e-3.
    """
    model = CatnDogClassifier().to(device)
    loss_fn = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    return model, loss_fn, optimizer

In [None]:
class VGG(nn.Module):
  """ImageNet-pretrained VGG19 wrapper that captures gradients for Grad-CAM.

  ``forward`` does three things:

  - runs ``features[:36]``, presumably every VGG19 feature layer except the
    final max-pool, which ``self.max_pool`` re-adds;
  - hooks that feature map so its gradient lands in ``self.gradients``;
  - finishes with the pool and the stock VGG19 classifier.
  """

  def __init__(self):
    super(VGG, self).__init__()
    # NOTE(review): newer torchvision deprecates `pretrained=` in favour of `weights=`.
    self.vgg = vgg19(pretrained=True)
    self.features_conv = self.vgg.features[:36]
    self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    self.classifier = self.vgg.classifier
    self.gradients = None

  def activation_hook(self, grad):
    """Tensor hook: remember the gradient w.r.t. the last conv feature map."""
    self.gradients = grad

  def forward(self, x):
    x = self.features_conv(x)
    # register_hook raises on tensors outside autograd (e.g. under torch.no_grad()).
    if x.requires_grad:
      x.register_hook(self.activation_hook)
    x = self.max_pool(x)
    x = torch.flatten(x, 1)  # per-sample flatten so batches larger than 1 work
    return self.classifier(x)

  def get_activations_gradient(self):
    """Gradient recorded by the hook during the last backward pass (or None)."""
    return self.gradients

  def get_activations(self, x):
    """Feature map that the hook is attached to, computed for ``x``."""
    return self.features_conv(x)

In [None]:
# Grad-CAM demo on the pretrained VGG19 using the sample `im` loaded above
# (float CHW RGB in [0, 1]). This avoids depending on loaders defined later.
vgg = VGG().to(device)
vgg.eval()
# torchvision's pretrained ImageNet models expect ImageNet mean/std normalisation.
imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
img = (im[None].to(device) - imagenet_mean) / imagenet_std
# Keep the raw logits, shape (1, 1000); the next cell back-propagates from one of them.
pred = vgg(img)

In [None]:
# Explain the ImageNet class VGG19 ranks highest for the first image. The
# original hard-coded index 386 (reportedly "African elephant" — verify).
target_class = int(pred[0].argmax())
vgg.zero_grad()
# .sum() turns the selected logits into a scalar, so backward() also works for batches.
pred[:, target_class].sum().backward()
gradients = vgg.get_activations_gradient()
# Average over batch and spatial dims -> one weight per feature-map channel.
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])


In [None]:
from torchsummary import summary

model, loss_fn, optimizer = get_model()
# torchsummary defaults to device="cuda"; pass the notebook's device so this
# also works on CPU-only runtimes.
summary(model, input_size=(3, 224, 224), device=device)

In [None]:
def train_batch(x, y, model, opt, loss_fn):
    """Run one optimisation step of ``model`` on ``(x, y)`` and return the loss.

    The step uses the optimizer passed as ``opt``. Gradients are cleared before
    backward, so stale gradients (e.g. from a Grad-CAM backward pass) are not
    applied. ``model`` is put in training mode.
    """
    model.train()
    opt.zero_grad()
    prediction = model(x)
    batch_loss = loss_fn(prediction, y)
    batch_loss.backward()
    opt.step()
    return batch_loss.item()


@torch.no_grad()
def accuracy(x, y, model):
    """Per-sample correctness of thresholded predictions.

    ``model`` is expected to output probabilities; a prediction counts as class
    1 when it exceeds 0.5. Returns a nested Python list of bools shaped like
    ``y``, e.g. ``[[True], [False], ...]`` for ``(B, 1)`` labels.
    """
    model.eval()  # score with running BatchNorm stats instead of updating them
    prediction = model(x)
    is_correct = (prediction > 0.5) == y
    return is_correct.cpu().numpy().tolist()

In [None]:
def get_data(batch_size=32):
    """Build the train and validation DataLoaders.

    Training images use the augmenting ``trn_tfms`` and are shuffled.
    Validation images use the deterministic ``val_tfms`` and keep a fixed
    order, so validation metrics are not affected by random augmentation.
    Both loaders use the dataset's own ``collate_fn``.
    """
    train = cats_dogs(train_data_dir, transform=trn_tfms)
    trn_dl = DataLoader(train, batch_size=batch_size, shuffle=True, collate_fn=train.collate_fn)
    val = cats_dogs(test_data_dir, transform=val_tfms)
    val_dl = DataLoader(val, batch_size=batch_size, shuffle=False, collate_fn=val.collate_fn)
    return trn_dl, val_dl

In [None]:
@torch.no_grad()
def val_loss(x, y, model, criterion=None):
    """Loss of ``model`` on one batch, computed without gradient tracking.

    ``criterion`` defaults to the notebook-level ``loss_fn`` (from
    ``get_model()``). The model is switched to eval mode so validation data
    does not update BatchNorm running statistics.
    """
    if criterion is None:
        criterion = loss_fn
    model.eval()
    batch_loss = criterion(model(x), y)
    return batch_loss.item()

In [None]:
# Fresh data loaders, then a freshly initialised model, loss and optimizer.
loaders = get_data()
trn_dl, val_dl = loaders
model, loss_fn, optimizer = get_model()

In [None]:
N_EPOCHS = 10

train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []
for epoch in range(N_EPOCHS):
    # Optimisation pass over the (augmented) training set.
    # collate_fn yields (imgs, labels, raw_imgs); only the first two are used here.
    train_epoch_losses = []
    for batch in trn_dl:
        x, y = batch[:2]
        train_epoch_losses.append(train_batch(x, y, model, optimizer, loss_fn))
    train_epoch_loss = np.mean(train_epoch_losses)

    # Training-set accuracy with the updated weights.
    train_epoch_accuracies = []
    for batch in trn_dl:
        x, y = batch[:2]
        train_epoch_accuracies.extend(accuracy(x, y, model))
    train_epoch_accuracy = np.mean(train_epoch_accuracies)

    # Validation accuracy and mean per-batch validation loss.
    val_epoch_losses, val_epoch_accuracies = [], []
    for batch in val_dl:
        x, y = batch[:2]
        val_epoch_accuracies.extend(accuracy(x, y, model))
        val_epoch_losses.append(val_loss(x, y, model))
    val_epoch_accuracy = np.mean(val_epoch_accuracies)
    validation_loss = np.mean(val_epoch_losses)

    train_losses.append(train_epoch_loss)
    train_accuracies.append(train_epoch_accuracy)
    val_losses.append(validation_loss)
    val_accuracies.append(val_epoch_accuracy)
    print(f'epoch {epoch}: train loss {train_epoch_loss:.4f} acc {train_epoch_accuracy:.3f} | '
          f'val loss {validation_loss:.4f} acc {val_epoch_accuracy:.3f}')

In [None]:
# Feature extractor for Grad-CAM: the first five conv blocks, then the Conv and
# ReLU of the sixth. conv_layer is Conv, ReLU, BatchNorm, MaxPool, so the
# output is the sixth block's post-ReLU feature map.
im2fmap = nn.Sequential(*(list(model.model[:5].children()) + list(model.model[5][:2].children())))

In [None]:
def im2gradCAM(x):
    """Grad-CAM-style heatmap for a single-image batch ``x`` (shape (1, 3, H, W)).

    The score for the predicted class is back-propagated. The gradient of that
    score w.r.t. the weights of ``model.model[5][0]`` (the conv whose
    post-ReLU output ``im2fmap`` returns) is averaged per output channel.
    These channel weights scale ``im2fmap(x)``, and the weighted maps are
    averaged into one heatmap.

    Returns ``(heatmap, label)``: ``heatmap`` is a 2-D CPU tensor and
    ``label`` is ``'dog'`` or ``'cat'``. A sigmoid output above 0.5 means dog,
    following the dataset's dog=1 / cat=0 labels.
    """
    model.eval()
    prob_dog = model(x)[0, 0]  # single sigmoid unit -> probability of "dog"
    is_dog = prob_dog.item() > 0.5
    # Evidence for the predicted class: p(dog) for dogs, 1 - p(dog) for cats.
    score = prob_dog if is_dog else 1 - prob_dog
    model.zero_grad()
    score.backward()
    # One weight per output channel: mean gradient over (in_channels, kh, kw).
    pooled_grads = model.model[5][0].weight.grad.mean((1, 2, 3))
    with torch.no_grad():
        activations = im2fmap(x)
        # Scale each feature map by its channel weight, then average over channels.
        heatmap = (activations * pooled_grads[None, :, None, None]).mean(dim=1)[0]
    return heatmap.cpu(), 'dog' if is_dog else 'cat'

In [None]:
SZ = 224


def upsampleHeatmap(map, img, size=SZ, alpha=0.7):
    """Colourise a low-resolution heatmap, upsample it and blend it onto ``img``.

    Args:
        map: 2-D heatmap (numpy array or CPU tensor not requiring grad).
        img: image to overlay; presumably ``(size, size, 3)`` with 0-255 values —
            TODO confirm at call sites.
        size: side length of the output (defaults to ``SZ``).
        alpha: weight of the coloured heatmap in the blend; the image gets ``1 - alpha``.

    Returns:
        ``uint8`` array of shape ``(size, size, 3)``.
    """
    heat = np.asarray(map, dtype=np.float32)
    lo, hi = heat.min(), heat.max()
    span = hi - lo
    # Min-max scale to 0-255; a constant map would divide by zero, so render it as zeros.
    heat = 255 * (heat - lo) / span if span > 0 else np.zeros_like(heat)
    heat = np.uint8(heat)
    heat = cv2.resize(heat, (size, size))
    # NOTE(review): applyColorMap returns BGR while the overlay image is RGB.
    # Together with the `255 - heat` inversion, this appears to render high
    # values red when shown with matplotlib. Verify visually before changing either.
    heat = cv2.applyColorMap(255 - heat, cv2.COLORMAP_JET)
    return np.uint8(heat * alpha + img * (1 - alpha))

In [None]:
N = 20
val_ds = cats_dogs(test_data_dir, transform=val_tfms)
_val_dl = DataLoader(val_ds, batch_size=N, shuffle=True, collate_fn=val_ds.collate_fn)
x, y, z = next(iter(_val_dl))

# Every image in the batch is shown, whatever its predicted class.
for i in range(len(z)):
    # z[i] is the untransformed float CHW image in [0, 1]; make it HWC uint8 RGB.
    image = np.uint8(z[i].permute(1, 2, 0).cpu().numpy() * 255)
    image = cv2.resize(image, (SZ, SZ))
    heatmap, pred = im2gradCAM(x[i:i+1])
    heatmap = upsampleHeatmap(heatmap, image)
    fig, axes = plt.subplots(1, 2, figsize=(5, 3))
    for ax, panel, title in zip(axes, [image, heatmap], ['image', 'Grad-CAM']):
        ax.imshow(panel)
        ax.set_title(title)
        ax.axis('off')
    fig.suptitle(pred)
    plt.show()