In [None]:
import argparse
from tqdm import tqdm
from tensorboardX import SummaryWriter
from PIL import Image
import glob
import pandas as pd
import torch,torchvision
from matplotlib import pyplot as plt
from torchvision import transforms
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
import datetime  
from torch.utils import data
import random
import numpy as np
import time
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
import cv2
import numpy

In [None]:
%matplotlib inline

In [None]:
# Seed every random number generator the pipeline can draw from, so the
# random split, the shuffling and the augmentations are repeatable after a
# kernel restart.
SEED = 1
# Python's own RNG was imported but never seeded.
# NOTE(review): some torchvision releases draw augmentation parameters from
# it rather than from torch's generator - confirm for the installed version.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # every visible GPU, not only the current one

# Let cuDNN auto-tune its convolution algorithms for the fixed input size.
# The tuner may select different algorithms from run to run, so results are
# not guaranteed to be bit-for-bit identical even with the seeds above.
torch.backends.cudnn.benchmark = True


In [None]:
# Run on the first GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Directory that holds the training *.tif images.
# NOTE(review): the original cell left this value blank (a syntax error).
# Replace the placeholder with the real data directory before running.
path = "path/to/training/data"

# glob returns files in arbitrary, filesystem-dependent order; sorting makes
# the seeded random split pick the same images on every machine.
files = sorted(glob.glob(path+'/*.tif'))
if not files:
    raise FileNotFoundError(f"No .tif files found in {path!r}")

# One row per image; the class is encoded in the file path:
#   'LABELAB30'  -> 1
#   'LABELVEH00' -> 0
df = pd.DataFrame({'Path': files})
df['Label'] = np.nan
df.loc[df['Path'].str.contains('LABELAB30', regex=False), 'Label'] = 1
df.loc[df['Path'].str.contains('LABELVEH00', regex=False), 'Label'] = 0

# Fail here, with the offending file names, instead of later inside the
# DataLoader when int(NaN) is attempted on an unlabelled row.
unlabeled = df['Label'].isna()
if unlabeled.any():
    examples = df.loc[unlabeled, 'Path'].head(5).tolist()
    raise ValueError(
        f"{int(unlabeled.sum())} file(s) match neither 'LABELAB30' nor "
        f"'LABELVEH00', e.g. {examples}"
    )
df['Label'] = df['Label'].astype(int)

class MyDataset(torch.utils.data.Dataset):
    """Image dataset backed by a DataFrame with ``Path`` and ``Label`` columns.

    Each item is read from disk, converted to single-channel ('L') mode and
    passed through the optional ``transform``.
    """

    def __init__(self, dataframe, transform=None):
        """
        Args:
            dataframe: pandas DataFrame; row ``i`` gives the image file in
                ``Path`` and its class in ``Label``.
            transform: optional callable applied to the loaded PIL image.
        """
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of rows (images) in the DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the row at position ``index``.

        The label is returned as a Python ``int``.
        """
        row = self.dataframe.iloc[index]
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle as soon as convert() has produced an
        # independent in-memory copy, instead of waiting for garbage
        # collection.
        with Image.open(row["Path"]) as source:
            # Convert the image to grayscale.
            img = source.convert('L')
        label = int(row["Label"])
        # Compare with None so that a falsy-but-valid transform object
        # (e.g. an empty container of transforms) is still applied.
        if self.transform is not None:
            img = self.transform(img)
        return (
            img,
            label,
        )
    
class DatasetFromSubset(torch.utils.data.Dataset):
    """Wrap an indexable collection of ``(sample, target)`` pairs.

    An optional ``transform`` is applied to the sample only, which lets
    several views with different transforms share the same underlying subset.
    """

    def __init__(self, subset, transform=None):
        """Store the wrapped ``subset`` and the optional sample ``transform``."""
        self.subset, self.transform = subset, transform

    def __len__(self):
        """Return the number of items in the wrapped subset."""
        return len(self.subset)

    def __getitem__(self, index):
        """Return ``(sample, target)``, with the sample transformed if configured."""
        sample, target = self.subset[index]
        if not self.transform:
            return sample, target
        return self.transform(sample), target

# Pre-processing shared by every split: Resize(512) is applied when the base
# dataset loads an image, before any split-specific transform.
tsfm = transforms.Compose([transforms.Resize(512)])
dataset = MyDataset(df,transform = tsfm)

# Normalisation statistics were calculated over the *original* whole dataset.
# They are written as 1-tuples (one entry per channel); the original
# `(.6662)` was a bare float.
NORM_MEAN, NORM_STD = (.6662,), (.2179,)
CROP_SIZE = 448
BATCH_SIZE = 8

# Training view: random crop plus geometric/photometric augmentation.
tsfm_train = transforms.Compose([transforms.RandomCrop(CROP_SIZE),
                                 transforms.RandomRotation(30),
                                 transforms.RandomHorizontalFlip(p=0.5),
                                 transforms.RandomVerticalFlip(p=0.5),
                                 transforms.ColorJitter(brightness=0.2, contrast=0.2),
                                 transforms.ToTensor(),
                                 transforms.Normalize(NORM_MEAN, NORM_STD)
                                ])
# Validation/test view: deterministic centre crop only.
tsfm_val = transforms.Compose([transforms.CenterCrop(CROP_SIZE),
                               transforms.ToTensor(),
                               transforms.Normalize(NORM_MEAN, NORM_STD)
                              ])

# Divide into train/val/test (80/10/10). The test split takes the remainder
# so the three sizes always sum to len(dataset); with three independent
# int() truncations random_split raises for most dataset sizes.
n = len(dataset)
n_train = int(0.8*n)
n_val = int(0.1*n)
n_test = n - n_train - n_val

# Call random_split through the fully qualified module: the bare name `data`
# is rebound to a batch tuple by the training/evaluation loops, which would
# break a re-run of this cell.
train_set, val_set, test_set = torch.utils.data.random_split(
    dataset, [n_train, n_val, n_test])

train_data = DatasetFromSubset(train_set,transform = tsfm_train)
val_data = DatasetFromSubset(val_set,transform = tsfm_val)
test_data = DatasetFromSubset(test_set,transform = tsfm_val)

# Dataloaders. Pinned host memory only helps host-to-GPU copies, so it is
# enabled only when the model runs on CUDA. Evaluation loaders are not
# shuffled: order does not affect accuracy and a fixed order keeps runs
# comparable.
use_pinned_memory = device.type == "cuda"
train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE,
                                           shuffle=True, pin_memory=use_pinned_memory)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE,
                                         shuffle=False, pin_memory=use_pinned_memory)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE,
                                          shuffle=False, pin_memory=use_pinned_memory)

In [None]:
# Start from a pretrained ResNet-18.
# NOTE(review): newer torchvision releases prefer the `weights=` argument
# over `pretrained=True`; kept as-is to match the installed version.
resnet18 = torchvision.models.resnet18(pretrained=True)

# Replace the classifier head with a freshly initialised 2-class layer.
# xavier_uniform_ is the in-place initialiser; the un-suffixed name is
# deprecated.
resnet18.fc = nn.Linear(in_features=resnet18.fc.in_features, out_features=2, bias=True)
nn.init.xavier_uniform_(resnet18.fc.weight)

# Freeze the pretrained backbone; only the new head stays trainable here.
for parameter in resnet18.parameters():
    parameter.requires_grad = False

for parameter in resnet18.fc.parameters():
    parameter.requires_grad = True

# Adapt the stem to 1-channel input by averaging the pretrained RGB filters
# over the channel dimension.
# TODO: try different aggregations.
conv1_init_weight = resnet18.conv1.weight.detach().mean(dim=1, keepdim=True)
# The averaged filters keep the pretrained spatial size, so the module is
# declared with that kernel size. The original declared kernel_size=3 and
# then overwrote `.weight.data` with the larger pretrained kernel, leaving
# the module's metadata inconsistent with the weights it actually applied.
# stride=1 / padding=1 are kept unchanged so the computation is the same.
# NOTE(review): confirm padding=1 (rather than a "same" padding) is intended.
resnet18.conv1 = nn.Conv2d(1, 64, kernel_size=conv1_init_weight.shape[-1],
                           stride=1, padding=1, bias=False)
with torch.no_grad():
    # copy_ checks shapes, unlike assigning to `.data`.
    resnet18.conv1.weight.copy_(conv1_init_weight)

# Move the model only after all layers have been replaced, so the new conv1
# and fc end up on the same device as the rest of the network.
resnet18 = resnet18.to(device)

In [None]:
start_epoch = 0

# Number of top-level sub-modules of the network; the training loop's
# progressive-unfreezing schedule counts layers from the end with it.
layer_count = len(list(resnet18.children()))

# Freeze every parameter. The original loop set `requires_grad` on the child
# *modules*, which only creates an unused attribute and freezes nothing.
for parameter in resnet18.parameters():
    parameter.requires_grad = False

# Keep the head and the last residual stage trainable from the start
# (this also avoids an optimizer with no trainable parameters).
for parameter in resnet18.fc.parameters():
    parameter.requires_grad = True
for parameter in resnet18.layer4.parameters():
    parameter.requires_grad = True

# Objects the training cell uses. They were not defined anywhere in the
# visible notebook, so Restart & Run All failed with NameError.
# NOTE(review): the loss, optimizer type and learning rate below are
# placeholders - replace them with the settings used for the reported results.
loss_fn = nn.CrossEntropyLoss()
# All parameters are registered so that layers unfrozen later in training are
# updated too; parameters that are frozen receive no gradient and are skipped
# by the optimizer step.
optimizer = optim.Adam(resnet18.parameters(), lr=1e-4)
writer = SummaryWriter()


In [None]:

NUM_EPOCHS = 50        # number of the last epoch (inclusive)
UNFREEZE_EVERY = 10    # one more top-level module is unfrozen every N epochs
VAL_EVERY = 1          # run a validation pass every N epochs

for epoch in range(start_epoch+1, NUM_EPOCHS+1):
    # set models to train mode
    # NOTE(review): train() also puts frozen BatchNorm layers in training
    # mode, so their running statistics keep updating - confirm intended.
    resnet18.train()
    print(epoch)

    # Progressive unfreezing: counting top-level children from the output
    # side (fc = 0, avgpool = 1, layer4 = 2, ...), a child is trainable once
    # its position is below epoch // UNFREEZE_EVERY.
    n_unfrozen = epoch // UNFREEZE_EVERY
    for depth_from_end, child in enumerate(reversed(list(resnet18.children()))):
        trainable = depth_from_end < n_unfrozen
        for parameter in child.parameters():
            parameter.requires_grad = trainable
    # from the first epoch, unfreeze last two blocks
    for parameter in resnet18.fc.parameters():
        parameter.requires_grad = True
    for parameter in resnet18.layer4.parameters():
        parameter.requires_grad = True

    pbar = tqdm(train_loader,
                total=len(train_loader),bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
    start_time = time.time()

    # The batch is unpacked directly instead of being bound to `data`, which
    # would shadow the `torch.utils.data` alias imported at the top.
    for imgs, labels in pbar:
        imgs = imgs.to(device)
        labels = labels.to(device)
        # time spent waiting for / transferring the batch
        prepare_time = time.time() - start_time

        outputs = resnet18(imgs)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # time spent in forward/backward/step, and its share of the iteration
        process_time = time.time() - start_time - prepare_time
        iteration_time = process_time + prepare_time
        compute_efficiency = process_time/iteration_time if iteration_time > 0 else 0.0
        pbar.set_description(
            f'Compute efficiency: {compute_efficiency:.4f}, '
            f'loss: {loss.item():.4f},  epoch: {epoch}/{NUM_EPOCHS}')
        start_time = time.time()

    # validation pass every VAL_EVERY epochs
    if epoch % VAL_EVERY == 0:
        # bring models to evaluation mode
        resnet18.eval()
        correct = 0
        total = 0
        pbar = tqdm(val_loader,
                total=len(val_loader),bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
        with torch.no_grad():
            for imgs, labels in pbar:
                imgs = imgs.to(device)
                labels = labels.to(device)
                predicted = resnet18(imgs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Skip reporting when the validation split is empty (avoids 0/0).
        if total:
            val_accuracy = correct/total
            print(f'Accuracy on validation set: {100*val_accuracy:.2f}')
            # update tensorboardX
            writer.add_scalar('Accuracy',val_accuracy, epoch)

# Test on unseen data

In [None]:
# Final accuracy on the held-out test split. (The accuracy printed inside
# the training loop is computed on the *validation* split.)
resnet18.eval()
correct = 0
total = 0
pbar = tqdm(test_loader,
        total=len(test_loader),bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')
with torch.no_grad():
    # Unpack the batch directly: binding it to `data` would shadow the
    # `torch.utils.data` alias and break a re-run of the data-split cell.
    for imgs, labels in pbar:
        imgs = imgs.to(device)
        labels = labels.to(device)
        predicted = resnet18(imgs).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
# An empty test split would otherwise raise ZeroDivisionError.
if total:
    print(f'Accuracy on test set: {100*correct/total:.2f}')
else:
    print('Test set is empty; accuracy is undefined.')

# Grad-CAM on monochrome input

In [None]:
# SECURITY NOTE(review): torch.load unpickles arbitrary Python objects; only
# load checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases may default to weights-only loading,
# in which case a fully pickled model needs an explicit opt-out - confirm for
# the installed version.
# TODO: replace the machine-specific absolute path with a configurable one.
CHECKPOINT_PATH = "D:/CellClass/resnet18_weights.pth"

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
# machine (and vice versa).
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
if isinstance(checkpoint, dict):
    # A state dict: load it into the model built earlier in the notebook.
    # Binding the dict itself to `resnet18` would break every later cell.
    resnet18.load_state_dict(checkpoint)
else:
    # A fully pickled model object.
    resnet18 = checkpoint
resnet18 = resnet18.to(device).eval()

In [None]:
# Second view of the validation split that applies only the centre crop, so
# the sample stays an image that can be handed to numpy / cv2.
tsfm_val_grad = transforms.Compose([transforms.CenterCrop(448)])
val_grad = DatasetFromSubset(val_set, transform=tsfm_val_grad)

index = 100
img1, label1 = val_data[index]   # cropped + normalised tensor (network input)
img2, label2 = val_grad[index]   # same sample, cropped only (overlay background)

# Show both views side by side.
for position, picture in enumerate((img1.permute(1, 2, 0), img2), start=1):
    plt.subplot(1, 2, position)
    plt.imshow(picture)

# Background for the CAM overlay: grayscale -> 3-channel, scaled by 1/255,
# stored as float32.
rgb_image = (
    cv2.cvtColor(numpy.array(img2), cv2.COLOR_GRAY2RGB) / 255
).astype('float32')

In [None]:
# One CAM target per residual stage: the last block of layer1 .. layer4.
# NOTE: a later cell rebinds `target_layer` / `layer_name` to a
# finer-grained list of individual conv layers.
target_layer = [
    stage[-1]
    for stage in (resnet18.layer1, resnet18.layer2, resnet18.layer3, resnet18.layer4)
]
layer_name = [f'Layer{stage_no}' for stage_no in range(1, 5)]

In [None]:
# Display the first block of layer2 to inspect its sub-modules; the next
# cell selects conv1/conv2 of blocks like this one as CAM targets.
resnet18.layer2[0]

In [None]:
# CAM targets at individual-convolution granularity, kept as
# (panel title, module) pairs so a title cannot drift away from its layer.
_cam_targets = [
    ('L2,B1,Conv1', resnet18.layer2[0].conv1),
    ('L2,B2,Conv2', resnet18.layer2[-1].conv2),
    ('L3,B1,Conv1', resnet18.layer3[0].conv1),
    ('L3,B2,Conv2', resnet18.layer3[-1].conv2),
    ('L4,B1,Conv1', resnet18.layer4[0].conv1),
    ('L4,B1,Conv2', resnet18.layer4[0].conv2),
    ('L4,B2,Conv1', resnet18.layer4[-1].conv1),
    ('L4,B2,Conv2', resnet18.layer4[-1].conv2),
]
target_layer = [module for _, module in _cam_targets]
layer_name = [title for title, _ in _cam_targets]

In [None]:
# XGrad-CAM maps for class 1, one panel per target layer on a 2 x 4 grid.
input_tensor = img1.unsqueeze(0).to(device)
target_category = 1
plt.figure(figsize=(22, 10))
for panel, layer in enumerate(target_layer, start=1):
    cam = XGradCAM(model=resnet18, target_layer=layer)
    # The CAM call returns one map per input; keep the map of the only image.
    grayscale_cam = cam(input_tensor=input_tensor,
                        target_category=target_category)[0, :]
    visualization = show_cam_on_image(rgb_image, grayscale_cam)
    plt.subplot(2, 4, panel)
    plt.axis('off')
    plt.title(layer_name[panel - 1])
    plt.imshow(visualization)
plt.tight_layout()
# Only this many panels are plotted because of memory limits.
#plt.savefig('Xgrad_2.png')

In [None]:
# Grad-CAM maps for the sample's own label, one panel per target layer on a
# 4 x 2 grid; the figure is written to 'grad.png'.
input_tensor = img1.unsqueeze(0).to(device)
target_category = label1
plt.figure(figsize=(13, 22))
for panel, layer in enumerate(target_layer, start=1):
    cam = GradCAM(model=resnet18, target_layer=layer)
    # The CAM call returns one map per input; keep the map of the only image.
    grayscale_cam = cam(input_tensor=input_tensor,
                        target_category=target_category)[0, :]
    visualization = show_cam_on_image(rgb_image, grayscale_cam)
    plt.subplot(4, 2, panel)
    plt.title(layer_name[panel - 1])
    plt.imshow(visualization)
# Only this many panels are plotted because of memory limits.
plt.savefig('grad.png')