In [1]:
import os
import sys

import wandb

# Authenticate with Weights & Biases before any run is started.
wandb.login()

# Directory the next cell imports `models`, `train` and `generate_data` from
# (assumed). Override it with the NOTEBOOKS_DIR environment variable instead of
# editing this absolute path.
NOTEBOOKS_DIR = os.environ.get("NOTEBOOKS_DIR", "/home/malt/Documents/Notebooks")
# Guard so re-running this cell does not append duplicate sys.path entries.
if NOTEBOOKS_DIR not in sys.path:
    sys.path.append(NOTEBOOKS_DIR)

In [2]:
from models import *
from train import *
from generate_data import *

import numpy as np
import pandas as pd
from tqdm import tqdm as tqdm
from matplotlib import pyplot as plt 
%matplotlib inline




import torch
import torchvision
from torch.utils.data import Dataset,DataLoader

import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
import torchvision.transforms as transforms



# Prefer reproducible cuDNN kernels over autotuned (possibly faster) ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Run on CUDA when it is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [3]:
def Create_Mosaic_data(desired_num,m,foreground_label,background_label,foreground_data,background_data,dataset="None"):
    """Build `desired_num` mosaics, each assembled by `create_mosaic_img` from `m` patches.

    Each mosaic draws m-1 background indices and one foreground index, plus the
    position (0..m-1) at which the foreground patch is placed. NumPy's global
    RNG is re-seeded with `i` before mosaic `i`, so the random draws for mosaic
    `i` depend only on `i` (this also resets the global NumPy RNG state).

    Args:
        desired_num: number of mosaics to generate.
        m: number of patches per mosaic.
        foreground_label, background_label: label tensors of the two pools.
        foreground_data, background_data: image tensors of the two pools.
        dataset: "training" samples from the first 35000 background / 15000
            foreground items, "test" from 7000 / 3000. Any other value samples
            from the whole pools (len(background_label) / len(foreground_label)).

    Returns:
        (mosaic_data, mosaic_label, fore_idx), three lists with one entry per
        mosaic: the image list and label returned by `create_mosaic_img`
        (presumably the foreground class — defined elsewhere), and the
        foreground patch position.
    """
    pool_sizes = {"training": (35000, 15000), "test": (7000, 3000)}
    # Unknown names previously left n_bg/n_fg unbound (UnboundLocalError);
    # fall back to sampling from the full pools instead.
    n_bg, n_fg = pool_sizes.get(dataset, (len(background_label), len(foreground_label)))

    mosaic_data = []   # one entry per mosaic, as returned by create_mosaic_img
    fore_idx = []      # position (0..m-1) of the foreground patch in each mosaic
    mosaic_label = []  # label returned by create_mosaic_img for each mosaic
    for i in tqdm(range(desired_num)):
        # Per-mosaic seeding; the order of the three draws below must not change
        # or previously generated datasets will no longer be reproduced.
        np.random.seed(i)
        bg_idx = np.random.randint(0, n_bg, m - 1)
        fg_idx = np.random.randint(0, n_fg)
        fg = np.random.randint(0, m)
        fore_idx.append(fg)
        image_list, label = create_mosaic_img(foreground_data, background_data, foreground_label,
                                              bg_idx, fg_idx, fg, m)
        mosaic_data.append(image_list)
        mosaic_label.append(label)
    print("Mosaic Data Created")
    return mosaic_data, mosaic_label, fore_idx

In [4]:
fg1, fg2, fg3 = 0,1,2
batch_size = 10  # images per CIFAR-10 loader batch (shared by both loaders)

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# NOTE(review): testset/testloader are built, but the visible cells create both
# mosaic splits from the CIFAR-10 *training* images; confirm this is intended.
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

foreground_classes = {'plane', 'car', 'bird'}

background_classes = {'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'}

# Split every CIFAR-10 training image into the foreground or background pool.
train_bg_data=[]
train_bg_label=[]
train_fg_data=[]
train_fg_label=[]

# Iterate the loader directly: DataLoader iterators no longer provide a
# `.next()` method, and this also handles a final partial batch correctly.
for images, labels in tqdm(trainloader):
    for img, label in zip(images, labels):
        if classes[label.item()] in background_classes:
            train_bg_data.append(img)
            train_bg_label.append(label)
        else:
            train_fg_data.append(img)
            train_fg_label.append(label)

# Stack the per-image tensors directly (no .tolist() round trip).
train_fg_data = torch.stack(train_fg_data)
train_fg_label = torch.stack(train_fg_label)
train_bg_data = torch.stack(train_bg_data)
train_bg_label = torch.stack(train_bg_label)
print("Train Foreground Background Data created")


m = 9                 # patches per mosaic
desired_num = 40000   # mosaics to generate (later split into train / held-out)

train_mosaic_data,train_mosaic_label,train_fore_idx = Create_Mosaic_data(desired_num,m,train_fg_label,
                                     train_bg_label,train_fg_data,train_bg_data,"training")

In [5]:
class MosaicDataset(Dataset):
  """Map-style dataset over pre-built mosaics.

  Item ``idx`` is the tuple ``(mosaic_list_of_images[idx], mosaic_label[idx],
  fore_idx[idx])``: the mosaic, its label and the position of its foreground patch.
  """

  def __init__(self, mosaic_list_of_images, mosaic_label, fore_idx):
    """
      Args:
        mosaic_list_of_images (sequence): one mosaic per item.
        mosaic_label (sequence): label of each mosaic.
        fore_idx (sequence): index of the foreground patch inside each mosaic.

      Raises:
        ValueError: if the three sequences do not have the same length
          (otherwise items would silently be dropped or raise IndexError later).
    """
    if not (len(mosaic_list_of_images) == len(mosaic_label) == len(fore_idx)):
      raise ValueError(
          "MosaicDataset: mosaics, labels and fore_idx must have equal lengths, got "
          f"{len(mosaic_list_of_images)}, {len(mosaic_label)}, {len(fore_idx)}")
    self.mosaic = mosaic_list_of_images
    self.label = mosaic_label
    self.fore_idx = fore_idx

  def __len__(self):
    return len(self.label)

  def __getitem__(self, idx):
    return self.mosaic[idx], self.label[idx], self.fore_idx[idx]

In [6]:
# The first `tr` generated mosaics are used for training; the remaining ones
# (index tr onwards) form the held-out evaluation split.
tr = 30000
batch = 256

msd = MosaicDataset(train_mosaic_data[:tr], train_mosaic_label[:tr], train_fore_idx[:tr])
msd1 = MosaicDataset(train_mosaic_data[tr:], train_mosaic_label[tr:], train_fore_idx[tr:])

train_loader = DataLoader(msd, batch_size=batch, shuffle=False)
test_loader = DataLoader(msd1, batch_size=batch, shuffle=False)

In [7]:
def print_analysis(data_loader,focus,classification,dataset="None"):
    """Run evaluation methods 1-3 on `data_loader` and print their metrics.

    All three evaluations are computed before anything is printed. Each method
    then gets a section: a "<dataset>_Evaluation Method k" header, a rule of 60
    '*' characters, and one line each for FTPT, FFPT, FTPF, FFPF and Accuracy.
    """
    sections = []
    for header, evaluate in (("_Evaluation Method 1", evaluation_method_1),
                             ("_Evaluation Method 2", evaluation_method_2),
                             ("_Evaluation Method 3", evaluation_method_3)):
        ftpt, ffpt, ftpf, ffpf, accuracy = evaluate(data_loader, focus, classification)
        sections.append((header, (("FTPT", ftpt), ("FFPT", ffpt), ("FTPF", ftpf),
                                  ("FFPF", ffpf), ("Accuracy", accuracy))))

    for header, rows in sections:
        print(str(dataset) + header)
        print("*" * 60)
        for metric_name, value in rows:
            print(metric_name, value)

In [8]:
class Focus_6(nn.Module):
  """Focus (patch-scoring) network for mosaics of 3x32x32 patches.

  For every patch it returns an unnormalised score (callers softmax these over
  the patch axis to get attention weights) and the tanh of the sixth conv
  layer's output, which the classification network consumes as features.
  """
  def __init__(self):
    super(Focus_6, self).__init__()
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=0,bias=False)
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=0,bias=False)
    self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=0,bias=False)
    self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=0,bias=False)
    self.conv5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=0,bias=False)
    self.conv6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1,bias=False)
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
    self.batch_norm1 = nn.BatchNorm2d(32,track_running_stats=False)
    self.batch_norm2 = nn.BatchNorm2d(64,track_running_stats=False)
    self.batch_norm3 = nn.BatchNorm2d(256,track_running_stats=False)
    # Channel-wise dropout on 4-D feature maps.
    self.dropout1 = nn.Dropout2d(p=0.05)
    # Used on flattened (N, features) vectors. Dropout2d is only defined for
    # 3-D/4-D feature maps; PyTorch deprecates 2-D input to it and newer
    # versions treat it as one unbatched map (zeroing whole samples), so use
    # plain element-wise dropout here.
    self.dropout2 = nn.Dropout(p=0.1)
    self.fc1 = nn.Linear(256,64,bias=False)
    self.fc2 = nn.Linear(64, 32,bias=False)
    self.fc3 = nn.Linear(32, 10,bias=False)
    self.fc4 = nn.Linear(10, 1,bias=False)

    for layer in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6,
                  self.fc1, self.fc2, self.fc3, self.fc4):
      torch.nn.init.xavier_normal_(layer.weight)

  def forward(self,z):
    """Score every patch of a batch of mosaics.

    Args:
      z: tensor reshaped to (batch * patches, 3, 32, 32), i.e. expected layout
         (batch, patches, 3, 32, 32).

    Returns:
      alpha: (batch, patches) unnormalised patch scores.
      features: tanh(conv6) maps kept flat as (batch * patches, 256, 3, 3)
        for 32x32 patches, so they can be fed to the classifier directly.
    """
    batch = z.size(0)
    patches = z.size(1)
    z = torch.reshape(z,(batch*patches,3,32,32))
    alpha,features = self.helper(z)
    alpha = torch.reshape(alpha,(batch,patches))
    return alpha,features

  def helper(self, x):
    """Run the CNN on (N, 3, 32, 32) patches; return (scores of shape (N,), tanh(conv6) maps)."""
    x = self.conv1(x)
    x = F.relu(self.batch_norm1(x))

    x = F.relu(self.conv2(x))
    x = self.pool(x)

    x = self.conv3(x)
    x = F.relu(self.batch_norm2(x))

    x = F.relu(self.conv4(x))
    x = self.pool(x)
    x = self.dropout1(x)

    x = self.conv5(x)
    x = F.relu(self.batch_norm3(x))

    x = self.conv6(x)
    # Features handed to the classifier: squashed into [-1, 1] (torch.tanh; F.tanh is deprecated).
    features = torch.tanh(x)
    # The scoring head continues from the raw conv6 output.
    x = F.relu(x)
    x = self.pool(x)

    x = x.view(x.size(0), -1)

    x = self.dropout2(x)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.dropout2(x)
    x = F.relu(self.fc3(x))
    x = self.fc4(x)
    return x[:, 0], features

In [9]:
class Classification_6(nn.Module):
  """Classifier head over Focus_6 features: (N, 256, 3, 3) maps -> (N, 3) logits.

  Inputs of 3x3 spatial size are reduced to 1x1 by the three padded max-pools
  and the final 2x2 average pool, giving a 512-d vector for fc1. Other spatial
  sizes may not reduce to 1x1 and would then fail at fc1.
  """
  def __init__(self):
    super(Classification_6, self).__init__()
    self.conv1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
    self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
    self.conv4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
    self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
    self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2,padding=1)
    self.batch_norm1 = nn.BatchNorm2d(128,track_running_stats=False)
    self.batch_norm2 = nn.BatchNorm2d(256,track_running_stats=False)
    self.batch_norm3 = nn.BatchNorm2d(512,track_running_stats=False)
    # Channel-wise dropout on 4-D feature maps.
    self.dropout1 = nn.Dropout2d(p=0.05)
    # Applied to the flattened (N, 512) vector: Dropout2d is only defined for
    # 3-D/4-D maps (2-D input is deprecated and may drop whole samples), so use
    # element-wise dropout.
    self.dropout2 = nn.Dropout(p=0.1)
    self.global_average_pooling = nn.AvgPool2d(kernel_size=2)
    self.fc1 = nn.Linear(512,128)
    self.fc2 = nn.Linear(128, 3)

    for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6):
      torch.nn.init.xavier_normal_(conv.weight)
      torch.nn.init.zeros_(conv.bias)

    torch.nn.init.xavier_normal_(self.fc1.weight)
    torch.nn.init.xavier_normal_(self.fc2.weight)
    torch.nn.init.zeros_(self.fc1.bias)
    torch.nn.init.zeros_(self.fc2.bias)

  def forward(self, x):
    """Map (N, 256, 3, 3) feature maps to (N, 3) class logits."""
    x = self.conv1(x)
    x = F.relu(self.batch_norm1(x))

    x = F.relu(self.conv2(x))
    x = self.pool(x)

    x = self.conv3(x)
    x = F.relu(self.batch_norm2(x))

    x = F.relu(self.conv4(x))
    x = self.pool(x)
    x = self.dropout1(x)

    x = self.conv5(x)
    x = F.relu(self.batch_norm3(x))

    x = F.relu(self.conv6(x))
    x = self.pool(x)
    x = self.global_average_pooling(x)
    # flatten(1) keeps the batch axis even for N == 1 (squeeze() removed it,
    # producing 1-D logits that broke softmax/argmax over dim=1).
    x = torch.flatten(x, 1)
    x = self.dropout2(x)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return x

In [10]:
def train_model1(data,focus,classification,focus_optimizer,classification_optimizer,Criterion):
    """Perform one joint optimisation step of the focus and classification networks.

    The focus scores are softmax-normalised over dim 1 into attention weights,
    the classifier is run on the focus features, and `my_cross_entropy`
    (defined elsewhere) combines outputs, labels, weights and `Criterion` into
    the loss that is back-propagated through both networks.

    Args:
        data: (images, labels, fore_idx) batch; fore_idx is not used here.
        focus, classification: the two networks, already on `device`.
        focus_optimizer, classification_optimizer: their optimizers.
        Criterion: per-sample loss handed to `my_cross_entropy`.

    Returns:
        The same four model/optimizer objects (updated in place), so existing
        call sites can keep reassigning them.
    """
    images, labels, _ = data
    images = images.float().to(device)
    labels = labels.to(device)

    focus_optimizer.zero_grad()
    classification_optimizer.zero_grad()

    alpha, features = focus(images)
    attention = torch.softmax(alpha, dim=1)
    logits = classification(features)
    loss = my_cross_entropy(logits, labels, attention, Criterion)

    loss.backward()
    focus_optimizer.step()
    classification_optimizer.step()

    return focus, classification, focus_optimizer, classification_optimizer

In [11]:
def calculate_metrics(focus,classification,dataloader,dataset="train"):
    """Print (and return) attention statistics of the focus network over a dataset.

    Focus scores are softmax-normalised over dim 1 into attention weights alpha,
    and three per-mosaic quantities are averaged over all mosaics:

    * sparsity: number of patches with alpha > 0.01;
    * simplex distance: Euclidean distance from alpha to the one-hot vector at
      argmax(alpha), i.e. the nearest vertex of the probability simplex;
    * entropy: Shannon entropy of alpha in bits (0 * log 0 treated as 0).

    The averages divide by the number of mosaics actually evaluated. The old
    hard-coded 30000/10000 keyed on dataset == "train" was wrong for the
    "training" name this notebook passes.

    Args:
        focus: module returning (scores, features); inputs are sent to the
            device of its first parameter.
        classification: only switched to eval mode; its predictions are not
            needed for these statistics.
        dataloader: iterable of (inputs, labels, foreground_index) batches.
        dataset: name printed in the header line.

    Returns:
        (mean_sparsity, mean_simplex_distance, mean_entropy).

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    focus.eval()
    classification.eval()
    # Use the model's own device instead of hard-coding "cuda".
    run_device = next(focus.parameters()).device

    alpha_batches = []
    with torch.no_grad():
        for inputs, _labels, _fidx in dataloader:
            scores, _ = focus(inputs.float().to(run_device))
            alpha_batches.append(torch.softmax(scores, dim=1).cpu().numpy())
    if not alpha_batches:
        raise ValueError("calculate_metrics: dataloader yielded no batches")

    alphas = np.concatenate(alpha_batches, axis=0)
    n_points = alphas.shape[0]

    sparsity_val = np.sum(alphas > 0.01)

    argmax_index = np.argmax(alphas, axis=1)
    simplex_pt = np.zeros(alphas.shape)
    simplex_pt[np.arange(n_points), argmax_index] = 1
    shortest_distance_simplex = np.sum(np.sqrt(np.sum((alphas - simplex_pt) ** 2, axis=1)))

    # alpha == 0 gives 0 * -inf = nan; nansum drops those terms, errstate silences the warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        entropy = np.sum(np.nansum(-alphas * np.log2(alphas), axis=1))

    mean_sparsity = sparsity_val / n_points
    mean_distance = shortest_distance_simplex / n_points
    mean_entropy = entropy / n_points
    print("dataset " + dataset)
    print(mean_sparsity, mean_distance, mean_entropy)
    return mean_sparsity, mean_distance, mean_entropy

In [12]:
# method 1
def evaluation_method_1(dataloader,focus,classification):
    r"""Evaluate with hard attention: classify only the highest-scoring patch.

    For each mosaic the patch j* = argmax_j softmax(scores)_j is selected and
    the prediction is argmax_k \sigma_k(g(f_{j*})), where f_j are the focus
    features of patch j and g is the classification network.

    Args:
        dataloader: iterable of (images, labels, foreground_index) batches;
            images is (batch, patches, ...) — the patch count is read from
            images.size(1) of each batch.
        focus: module returning (scores, features) with features laid out
            flat as (batch * patches, ...), as Focus_6 does. Inputs are sent
            to the device of its first parameter.
        classification: module mapping one patch's features to class logits.

    Returns:
        (ftpt, ffpt, ftpf, ffpf, accuracy) as percentages of all mosaics, where
        FT/FF = the selected patch is / is not the true foreground patch and
        PT/PF = the class prediction is right / wrong.
    """
    predicted_indexes = []
    foreground_index_list = []
    prediction_list = []
    labels_list = []
    focus.eval()
    classification.eval()
    run_device = next(focus.parameters()).device
    with torch.no_grad():
        for images, labels, foreground_index in dataloader:
            images = images.float().to(run_device)
            foreground_index_list.append(foreground_index.numpy())
            labels_list.append(labels.numpy())
            # Both sizes come from this batch; `patches` used to be a leaked global.
            batch = images.size(0)
            patches = images.size(1)
            scores, features = focus(images)

            indexes = torch.argmax(F.softmax(scores, dim=1), dim=1)
            if indexes.dim() > 1:
                indexes = indexes[:, 0]
            predicted_indexes.append(indexes.cpu().numpy())

            features = features.reshape(batch, patches, *features.shape[1:])
            features = features[torch.arange(batch, device=features.device),
                                indexes.to(features.device)]
            outputs = F.softmax(classification(features), dim=1)
            prediction_list.append(torch.argmax(outputs, dim=1).cpu().numpy())

    predicted_indexes = np.concatenate(predicted_indexes, axis=0)
    foreground_index_list = np.concatenate(foreground_index_list, axis=0)
    prediction_list = np.concatenate(prediction_list, axis=0)
    labels_list = np.concatenate(labels_list, axis=0)

    focus_true = predicted_indexes == foreground_index_list
    prediction_true = prediction_list == labels_list
    n_points = len(foreground_index_list)

    ftpt = (np.sum(focus_true & prediction_true).item() / n_points) * 100
    ffpt = (np.sum(~focus_true & prediction_true).item() / n_points) * 100
    ftpf = (np.sum(focus_true & ~prediction_true).item() / n_points) * 100
    ffpf = (np.sum(~focus_true & ~prediction_true).item() / n_points) * 100
    accuracy = (np.sum(prediction_true, axis=0) / len(labels_list)) * 100

    return ftpt, ffpt, ftpf, ffpf, accuracy

In [13]:
def evaluation_method_2(dataloader,focus,classification):
    r"""Evaluate with attention-weighted class probabilities (soft voting).

    Every patch is classified separately and the per-patch class probabilities
    are averaged with the focus weights alpha = softmax(scores):
    prediction = argmax_k \sum_j \alpha_j \sigma_k(g(f_j)), where f_j are the
    focus features of patch j and g is the classification network.

    Args:
        dataloader: iterable of (images, labels, foreground_index) batches;
            images is (batch, patches, ...).
        focus: module returning (scores, features). Features may be flat
            (batch * patches, ...) or (batch, patches, ...). Inputs are sent to
            the device of its first parameter.
        classification: module mapping per-patch features to class logits.

    Returns:
        (ftpt, ffpt, ftpf, ffpf, accuracy) as percentages of all mosaics; the
        "focus true" part compares argmax(alpha) with the foreground index.
    """
    predicted_indexes = []
    foreground_index_list = []
    prediction_list = []
    labels_list = []
    focus.eval()
    classification.eval()
    run_device = next(focus.parameters()).device
    with torch.no_grad():
        for images, labels, foreground_index in dataloader:
            images = images.float().to(run_device)
            batch = images.size(0)
            patches = images.size(1)
            foreground_index_list.append(foreground_index.numpy())
            labels_list.append(labels.numpy())

            alpha, features = focus(images)
            focus_outputs = F.softmax(alpha, dim=1)
            if focus_outputs.dim() > 2:
                focus_outputs = focus_outputs[:, :, 0]
            predicted_indexes.append(torch.argmax(focus_outputs, dim=1).cpu().numpy())

            # Classify every patch on its own; flatten (batch, patches, ...) if needed.
            if features.shape[0] != batch * patches:
                features = features.reshape(batch * patches, *features.shape[2:])
            classification_outputs = F.softmax(classification(features), dim=1)
            n_classes = classification_outputs.size(1)
            classification_outputs = classification_outputs.reshape(batch, patches, n_classes)

            weighted = torch.sum(focus_outputs[:, :, None] * classification_outputs, dim=1)
            prediction_list.append(torch.argmax(weighted, dim=1).cpu().numpy())

    predicted_indexes = np.concatenate(predicted_indexes, axis=0)
    foreground_index_list = np.concatenate(foreground_index_list, axis=0)
    prediction_list = np.concatenate(prediction_list, axis=0)
    labels_list = np.concatenate(labels_list, axis=0)

    focus_true = predicted_indexes == foreground_index_list
    prediction_true = prediction_list == labels_list
    n_points = len(foreground_index_list)

    ftpt = (np.sum(focus_true & prediction_true).item() / n_points) * 100
    ffpt = (np.sum(~focus_true & prediction_true).item() / n_points) * 100
    ftpf = (np.sum(focus_true & ~prediction_true).item() / n_points) * 100
    ffpf = (np.sum(~focus_true & ~prediction_true).item() / n_points) * 100
    accuracy = (np.sum(prediction_true, axis=0) / len(labels_list)) * 100

    return ftpt, ffpt, ftpf, ffpf, accuracy

In [14]:
# method 3
def evaluation_method_3(dataloader,focus,classification):
    r"""Evaluate by classifying the attention-weighted average of patch features.

    The focus features of all patches are averaged with alpha = softmax(scores)
    and the classifier is applied once per mosaic:
    prediction = argmax_k \sigma_k(g(\sum_j \alpha_j f_j)).

    Args:
        dataloader: iterable of (images, labels, foreground_index) batches;
            images is (batch, patches, ...) — the patch count is read from
            images.size(1) of each batch.
        focus: module returning (scores, features) with features laid out
            flat as (batch * patches, ...), as Focus_6 does. Inputs are sent
            to the device of its first parameter.
        classification: module mapping averaged features to class logits.

    Returns:
        (ftpt, ffpt, ftpf, ffpf, accuracy) as percentages of all mosaics; the
        "focus true" part compares argmax(alpha) with the foreground index.
    """
    predicted_indexes = []
    foreground_index_list = []
    prediction_list = []
    labels_list = []
    focus.eval()
    classification.eval()
    run_device = next(focus.parameters()).device
    with torch.no_grad():
        for images, labels, foreground_index in dataloader:
            images = images.float().to(run_device)
            foreground_index_list.append(foreground_index.numpy())
            labels_list.append(labels.numpy())
            # Both sizes come from this batch; `patches` used to be a leaked global.
            batch = images.size(0)
            patches = images.size(1)
            scores, features = focus(images)

            alphas = F.softmax(scores, dim=1)
            indexes = torch.argmax(alphas, dim=1)
            if indexes.dim() > 1:
                indexes = indexes[:, 0]
                alphas = alphas[:, :, 0]
            predicted_indexes.append(indexes.cpu().numpy())

            # Weighted sum over the patch axis; weights broadcast over all feature dims.
            features = features.reshape(batch, patches, *features.shape[1:])
            weights = alphas.reshape(batch, patches, *([1] * (features.dim() - 2)))
            averaged = torch.sum(weights * features, dim=1)

            outputs = F.softmax(classification(averaged), dim=1)
            prediction_list.append(torch.argmax(outputs, dim=1).cpu().numpy())

    predicted_indexes = np.concatenate(predicted_indexes, axis=0)
    foreground_index_list = np.concatenate(foreground_index_list, axis=0)
    prediction_list = np.concatenate(prediction_list, axis=0)
    labels_list = np.concatenate(labels_list, axis=0)

    focus_true = predicted_indexes == foreground_index_list
    prediction_true = prediction_list == labels_list
    n_points = len(foreground_index_list)

    ftpt = (np.sum(focus_true & prediction_true).item() / n_points) * 100
    ffpt = (np.sum(~focus_true & prediction_true).item() / n_points) * 100
    ftpf = (np.sum(focus_true & ~prediction_true).item() / n_points) * 100
    ffpf = (np.sum(~focus_true & ~prediction_true).item() / n_points) * 100
    accuracy = (np.sum(prediction_true, axis=0) / len(labels_list)) * 100

    return ftpt, ffpt, ftpf, ffpf, accuracy

In [15]:
nos_epochs = 50

learning_rates = [0.0001] #0.0025

run = wandb.init(project="Interpretability_CIFAR_10_Experiments",
                 name="HA_sixth_layer_averaging",
                 config={
                     "learning rate ": learning_rates,
                     # Taken from nos_epochs so the logged config cannot drift from the loop.
                     "epochs": nos_epochs,
                 },
                 save_code=True)


focus_lr_plots = []
classification_lr_plots = []
n_seeds = [0,1,2]
for seed in n_seeds:
    for run_no, lr in enumerate(learning_rates):
        # Re-seed before building each network so every run starts from the same weights.
        torch.manual_seed(seed)
        focus = Focus_6().to(device)

        torch.manual_seed(seed)
        classification = Classification_6().to(device)

        # Per-sample losses; my_cross_entropy (defined elsewhere) combines them with the attention weights.
        Criterion = nn.CrossEntropyLoss(reduction="none")
        focus_optimizer = optim.Adam(focus.parameters(), lr=lr)
        classification_optimizer = optim.Adam(classification.parameters(), lr=lr)

        loss_list = []

        for epoch in tqdm(range(nos_epochs)):
            focus.train()
            classification.train()

            epoch_loss = []

            for i, data in enumerate(train_loader):
                focus, classification, focus_optimizer, classification_optimizer = train_model1(
                    data, focus, classification, focus_optimizer, classification_optimizer, Criterion)

                # Recompute the loss after the update, without gradients, purely for logging.
                with torch.no_grad():
                    images, labels, fore_idx = data
                    batch = images.size(0)
                    patches = images.size(1)
                    # Cast exactly like train_model1 so both forward passes see the same dtype.
                    images, labels = images.float().to(device), labels.to(device)
                    alpha, features = focus(images)
                    alphas = torch.softmax(alpha, dim=1)
                    outputs = classification(features)
                    loss = my_cross_entropy(outputs, labels, alphas, Criterion)

                epoch_loss.append(loss.item())

            mean_loss = np.mean(epoch_loss)
            # A single log call per epoch: a second wandb.log call advanced the step counter again.
            wandb.log({"loss_lr_" + str(lr): mean_loss.item(), "Epoch": epoch + 1})

            loss_list.append(mean_loss)

    # Evaluated once per seed, after the learning-rate loop (with a single
    # learning rate this is that run's final model).
    print_analysis(train_loader,focus,classification,dataset="training")
    print_analysis(test_loader,focus,classification,dataset="testing")
    calculate_metrics(focus,classification,train_loader,"training")
    calculate_metrics(focus,classification,test_loader,"testing")
    print("Finished Training")
    # NOTE(review): the same file names are reused for every seed, so the local
    # .pth files only hold the most recent seed's weights.
    torch.save(focus.state_dict(), 'focus.pth')
    torch.save(classification.state_dict(), 'classification.pth')
    artifact = wandb.Artifact('model', type='model')
    artifact.add_file('focus.pth')
    artifact.add_file('classification.pth')
    run.log_artifact(artifact)
wandb.finish()