In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt 
import numpy as np
import seaborn as sn
import pandas as pd 
from sklearn.metrics import classification_report,matthews_corrcoef,confusion_matrix
from tqdm import tqdm
from Models import res50
import os 

In [None]:
def do_prediction(mdl, dset, btch_size, dvc, num_workers=8):
    """Run ``mdl`` over every sample of ``dset`` and collect true and predicted labels.

    Args:
        mdl: model called as ``mdl(images)``. Its output is reduced with
            ``argmax`` over dim 1, so it is assumed to return per-class scores of
            shape (batch, num_classes) -- TODO confirm for new models. The caller
            is responsible for putting it in eval mode and on ``dvc``.
        dset: dataset yielding ``(image, label)`` pairs (e.g. ``ImageFolder``).
        btch_size: batch size used by the DataLoader.
        dvc: torch device the images are moved to before the forward pass.
        num_workers: number of DataLoader worker processes (default 8).

    Returns:
        ``(y_tst, y_prd)``: two equally long lists of ints in dataset order
        (shuffling is disabled). They hold the ground-truth and predicted class
        indices respectively.
    """
    test_dataloader = DataLoader(dataset=dset, num_workers=num_workers, shuffle=False, batch_size=btch_size)
    y_tst = []
    y_prd = []
    # Gradients are never needed for evaluation, so one context covers the whole loop.
    with torch.no_grad():
        for images, labels in tqdm(test_dataloader, desc='test_data'):
            scores = mdl(images.to(dvc))
            y_prd.extend(torch.argmax(scores, dim=1).cpu().tolist())
            y_tst.extend(labels.tolist())
    return y_tst, y_prd

Testing the model trained on the real data on the external real test set

In [None]:
# Evaluate the classifier trained on real tiles (GTEX256_leaveslides checkpoint)
# on the external real test set; the confusion matrix is stored as cf['real-real'].
current_directory = os.getcwd()
path = os.path.abspath(os.path.join(current_directory,'..','Checkpoint-classifier','lightning_logs','GTEX256_leaveslides','checkpoints','epoch=79-step=83680.ckpt'))
data_dir = os.path.abspath(os.path.join(current_directory,'..','Test'))

model = res50.load_from_checkpoint(path)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
batch_size = 8
# Per-channel normalisation for the real-data model.
# NOTE(review): presumably the statistics of the real training tiles -- confirm.
transform = transforms.Compose([
                transforms.Resize(256),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.7358, 0.4814, 0.6079], 
                            std= [0.1314, 0.1863, 0.1465]
                            )
                ])
dataset = ImageFolder(root= data_dir,transform= transform)
class_labels = dataset.classes
num_classes = len(class_labels)
# Pin the full label set: a class missing from both y_test and y_pred would
# otherwise make target_names mismatch (ValueError). It would also shrink the
# confusion matrix below len(class_labels), which breaks the labelled heatmap.
label_ids = list(range(num_classes))
y_test,y_pred = do_prediction(model,dataset,batch_size,device)
print('\nClassification Report\n')
print(classification_report(np.array(y_test), np.array(y_pred), labels=label_ids, target_names= class_labels))
mcc = matthews_corrcoef(np.array(y_test), np.array(y_pred))
print('\nMatthews Correlation Coefficient\n')
print(mcc)
cf = {}
cf['real-real'] = confusion_matrix(np.array(y_test), np.array(y_pred), labels=label_ids)


Testing the model trained on the real data on the fake (synthetic, generated) data

In [None]:
# Evaluate the real-data classifier on the synthetic tiles. `model`, `device`,
# `batch_size`, `current_directory` and `cf` all come from the previous cell.
data_dir = os.path.abspath(os.path.join(current_directory,'..','Generated-data','synthetic_tiles_512TO256_GTEX'))

# Same normalisation as the real test set, because the model is the real-data one.
# NOTE(review): unlike the other cells there is no Resize(256). Presumably the
# synthetic tiles are already 256 px (see the folder name) -- confirm.
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.7358, 0.4814, 0.6079], 
                            std= [0.1314, 0.1863, 0.1465]
                            )
                ])
dataset = ImageFolder(root= data_dir,transform= transform)
class_labels = dataset.classes
num_classes = len(class_labels)
# Pin the full label set so that a class absent from this split still gets its
# own report row and confusion-matrix row/column.
label_ids = list(range(num_classes))
y_test,y_pred = do_prediction(model,dataset,batch_size,device)
print('\nClassification Report\n')
print(classification_report(np.array(y_test), np.array(y_pred), labels=label_ids, target_names= class_labels))
mcc = matthews_corrcoef(np.array(y_test), np.array(y_pred))
print('\nMatthews Correlation Coefficient\n')
print(mcc)
cf['real-fake'] = confusion_matrix(np.array(y_test), np.array(y_pred), labels=label_ids)

Testing the model trained on the fake (synthetic) data on the internal real test set

In [None]:
# Evaluate the classifier trained on synthetic tiles (256_GTEX_fakedata checkpoint)
# on the internal test set; the confusion matrix is stored as cf['fake-real'].
path = os.path.abspath(os.path.join(current_directory,'..','Checkpoint-classifier','lightning_logs','256_GTEX_fakedata','checkpoints','epoch=79-step=52560.ckpt'))
data_dir = os.path.abspath(os.path.join(current_directory,'..','Test_internal'))

model = res50.load_from_checkpoint(path)
# Hard-coding "cuda:1" raises on machines with a single GPU. Use the second GPU
# when it exists, otherwise the first one, otherwise the CPU.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = model.to(device)
model.eval()
batch_size = 8
# Normalisation for the synthetic-data model.
# NOTE(review): presumably the statistics of the synthetic training tiles -- confirm.
transform =  transforms.Compose([
                transforms.Resize(256),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.6895, 0.4152, 0.5501], 
                            std= [0.1044, 0.1355, 0.1108]
                            )
                ])
dataset = ImageFolder(root= data_dir,transform= transform)
class_labels = dataset.classes
num_classes = len(class_labels)
# Pin the full label set so that a class absent from this split still gets its
# own report row and confusion-matrix row/column.
label_ids = list(range(num_classes))
y_test,y_pred = do_prediction(model,dataset,batch_size,device)
print('\nClassification Report\n')
print(classification_report(np.array(y_test), np.array(y_pred), labels=label_ids, target_names= class_labels))
mcc = matthews_corrcoef(np.array(y_test), np.array(y_pred))
print('\nMatthews Correlation Coefficient\n')
print(mcc)
cf['fake-real'] = confusion_matrix(np.array(y_test), np.array(y_pred), labels=label_ids)

Testing the model trained on the fake (synthetic) data on the external real test set

In [None]:
# Evaluate the synthetic-data classifier on the external real test set. This cell
# deliberately reuses `model`, `device`, `batch_size` and `transform` (with the
# synthetic-data normalisation) from the previous cell.
data_dir = os.path.abspath(os.path.join(current_directory,'..','Test'))
dataset = ImageFolder(root= data_dir,transform= transform)
class_labels = dataset.classes
num_classes = len(class_labels)
# Pin the full label set so that a class absent from this split still gets its
# own report row and confusion-matrix row/column.
label_ids = list(range(num_classes))
y_test,y_pred = do_prediction(model,dataset,batch_size,device)
print('\nClassification Report\n')
print(classification_report(np.array(y_test), np.array(y_pred), labels=label_ids, target_names= class_labels))
mcc = matthews_corrcoef(np.array(y_test), np.array(y_pred))
print('\nMatthews Correlation Coefficient\n')
print(mcc)
cf['fake-real_external'] = confusion_matrix(np.array(y_test), np.array(y_pred), labels=label_ids)

Visualizing the results

In [None]:
# Draw every confusion matrix collected in `cf` as a heatmap, side by side in one
# figure, and save it as PNG and SVG in the working directory.
font_properties = {'fontname': 'sans-serif', 'fontsize': 12}
titles = ['Experiment 2.1', 'Experiment 2.2','Experiment 2.3','Experiment 2.4']
sup_ttl = 'Classifier Results'
sup_font_properties = {'fontname': 'sans-serif', 'fontsize': 14}
annot_properties={'fontsize': 8,'fontname': 'sans-serif'}
color = 'Blues'

num_figures = len(cf)
figsz = (num_figures * 3.54, 3.54)
# squeeze=False keeps `axes` indexable/iterable even when only one experiment
# was run; otherwise plt.subplots would return a bare Axes object.
fig, axes = plt.subplots(1, num_figures, gridspec_kw={'width_ratios': [1] * num_figures},
                         figsize=figsz, dpi=600, squeeze=False)
axes = axes[0]
fig.suptitle(sup_ttl, **sup_font_properties)

for num, (key, conf_matrix) in enumerate(cf.items()):
    # Fall back to the dict key if there are more experiments than titles.
    ttl = titles[num] if num < len(titles) else key
    # NOTE(review): class_labels comes from the last evaluated dataset; all
    # experiments are assumed to share the same class folders -- confirm.
    df_conf_matrix = pd.DataFrame(conf_matrix, index=list(class_labels), columns=list(class_labels))
    g = sn.heatmap(df_conf_matrix, square=True, cmap=color, annot=True, fmt=".0f",
                   annot_kws=annot_properties, cbar=False, ax=axes[num])
    if num == 0:
        g.set_ylabel('actual', **font_properties)
    else:
        # Only the left-most panel keeps row labels; the others share them.
        g.set_yticks([])
    g.set_xlabel('predicted', **font_properties)
    g.set_title(ttl, **font_properties)

for ax in axes:
    ax.tick_params(axis='both', labelsize=10, width=1)
fig.tight_layout()
fig.savefig(os.path.join(os.getcwd(), 'classifier_results.png'), dpi=600)
fig.savefig(os.path.join(os.getcwd(), 'classifier_results.svg'))