In [None]:
import os

# Must be set before TensorFlow is imported: assigning it afterwards does not
# silence the C++ log messages TensorFlow emits while it is being imported.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

from pathlib import Path
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
import importlib
import tensorflow as tf
import io
import torchvision.transforms as transforms
import tqdm

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from model import tcn
import scipy.io

In [None]:
def train(transform, learning_rate = .01, start_epoch = 0, num_folds = 10, epochs = 1, batch_size = 8):
    """Train `tcn` classifiers on CIFAR-10 and export their class activation maps.

    For each of `num_folds` runs a fresh `tcn(3, num_classes=10)` is created
    (optionally initialised from a saved checkpoint), optimised with SGD and
    cross-entropy, evaluated on the CIFAR-10 test split after every epoch,
    checkpointed, and its CAMs are rendered through `heats` / `plot_cams`.

    Checkpoint layout: the weights obtained after finishing epoch ``e`` are
    written to ``models/epoch {e+1}/epoch {e}``; resuming with
    ``start_epoch = s`` therefore reads ``models/epoch {s}/epoch {s-1}``.

    NOTE(review): every run uses the same train/test split and the same
    checkpoint / CAM output paths, so later runs overwrite the files of earlier
    ones -- confirm whether real cross-validation splits were intended.

    Args:
        transform: torchvision transform applied to every CIFAR-10 image.
        learning_rate: SGD learning rate.
        start_epoch: first epoch index; when > 0 the matching checkpoint is
            loaded before training continues.
        num_folds: number of independent training runs.
        epochs: epochs trained per run, counted from `start_epoch`.
        batch_size: mini-batch size for both data loaders.

    Returns:
        None. Results are written to ``models/`` and ``CAMs/`` and progress is
        printed.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The data does not depend on the run, so build datasets and loaders once
    # instead of re-reading CIFAR-10 from disk for every fold.
    trainset = torchvision.datasets.CIFAR10(root='.', train=True,
                                download=True, transform = transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle = True)

    testset = torchvision.datasets.CIFAR10(root='.', train=False,
                                download=True, transform = transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size)

    for fold in range(num_folds):
        print(device)
        net = tcn(3, num_classes = 10)

        if start_epoch > 0:
            PATH = 'models/epoch ' + str(start_epoch)
            # Load on CPU first; the model is moved to the target device below.
            net.load_state_dict(torch.load(f"{PATH}/epoch " + str(start_epoch -1), weights_only=True, map_location=torch.device('cpu')))

        net.to(device)
        criterion = nn.CrossEntropyLoss().to(device)
        optimizer = optim.SGD(net.parameters(), lr=learning_rate)

        for epoch in range(start_epoch, start_epoch + epochs):
            # test() / heats() switch the model to eval mode, so re-enable
            # training mode at the start of every epoch.
            net.train()
            pbar = tqdm.tqdm(trainloader)
            for inputs, ind_labels in pbar:
                inputs = inputs.float().to(device)
                ind_labels = ind_labels.to(device)

                optimizer.zero_grad()
                outputs, _ = net(inputs)
                # Class-index targets give the same loss as one-hot float
                # targets without materialising the one-hot matrix.
                loss = criterion(outputs, ind_labels)
                loss.backward()
                optimizer.step()
                pbar.set_description("Loss: %f" % loss.item())

            correct_epoch = test(net, testloader)
            print("fold %d epoch %d: %d/%d test images classified correctly"
                  % (fold, epoch, correct_epoch, len(testset)))

            PATH = 'models/epoch ' + str(epoch + 1)
            Path(PATH).mkdir(parents=True, exist_ok=True)
            torch.save(net.state_dict(), f"{PATH}/epoch " + str(epoch))

            # Collect CAMs on the test split and write one heatmap per class.
            plot_cams(heats(net, testloader), 10)


In [None]:
def test(net, testloader):
    net.eval()
    correct_epoch = 0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        for data, label in testloader:
            data = data.to(device)
            output, cams = net(data)
            _, predicted = torch.max(output.data, 1)
            labels  = label
            labels  = labels.to(device)
            correct_epoch += (predicted == labels).sum().item()
    return correct_epoch

In [None]:
def heats(net, testloader, max_batches=100):
    """Collect, per true class, the class activation map of each test sample.

    Only the first `max_batches` batches of `testloader` are used. For every
    sample the CAM row belonging to its ground-truth label is stored.

    Args:
        net: module called as ``net(images, generateCAMs=True)`` and returning
            ``(logits, cams)``.
            # assumes cams[b] squeezes to (num_classes, length) -- TODO confirm
        testloader: iterable of ``(images, labels)`` tensor batches with
            integer labels in 0..9.
        max_batches: number of batches to process (default 100).

    Returns:
        dict[int, list[np.ndarray]]: for each class id 0..9, the list of CAM
        rows of the samples whose label is that class (possibly empty).
    """
    net.eval()

    # Use the device the model lives on; fall back to the CUDA-availability
    # guess only for parameter-free modules.
    first_param = next(net.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    classes = list(range(10))
    heatmap = {classname: [] for classname in classes}

    with torch.no_grad():
        for batch_idx, (images, labels_temp) in enumerate(testloader):
            if batch_idx >= max_batches:
                break
            _, cams = net(images.to(device), generateCAMs = True)
            # One device-to-host copy per batch instead of one per sample.
            cams_np = cams.detach().cpu().numpy()
            for sample_idx, label in enumerate(labels_temp.tolist()):
                sample_cams = cams_np[sample_idx].squeeze()
                # Copy the row so the stored CAM does not keep the whole
                # batch array alive.
                heatmap[classes[label]].append(np.array(sample_cams[label, :]))

    return heatmap

def gen_plot():
    """Render a small placeholder line plot and return it as an in-memory PNG.

    The figure is closed before returning, so repeated calls do not pile up
    open pyplot figures and the call does not leave a new current figure
    behind.

    Returns:
        io.BytesIO: buffer containing the PNG bytes, rewound to position 0.
    """
    fig, ax = plt.subplots()
    try:
        ax.plot([1, 2])
        ax.set_title("test")
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # pyplot keeps every figure registered until it is closed explicitly.
        plt.close(fig)
    buf.seek(0)
    return buf


# Plot the CAM
def plot_cams(heats, num_classes):
    """Draw one CAM heatmap per class, save it as EPS and return the last image.

    For every class ``i`` in ``range(num_classes)`` that has at least one CAM,
    the CAMs are stacked into a matrix with one row per sample, drawn with
    seaborn, written to ``CAMs/{i}.eps`` and also rendered to PNG, decoded
    into a TensorFlow image tensor and passed to ``tf.summary.image``.

    Args:
        heats: mapping (or sequence) indexed by class id whose values are
            lists of equally long 1-D CAM arrays, as returned by `heats()`.
        num_classes: number of class ids to plot, starting at 0.

    Returns:
        The decoded heatmap image with a leading batch dimension for the last
        class that had data, or None when no class had any CAMs.
    """
    images_dir = 'CAMs'
    Path(images_dir).mkdir(parents=True, exist_ok=True)

    image = None
    for i in range(num_classes):
        class_cams = heats[i]
        if len(class_cams) == 0:
            # No sample of this class was collected; there is nothing to draw.
            continue

        fig, ax = plt.subplots()
        try:
            sns.heatmap(np.array(class_cams), xticklabels=len(class_cams[0]) - 1, yticklabels=len(class_cams)//5, cmap=sns.color_palette("Spectral", as_cmap=True), ax=ax, cbar=False)
            # Style the heatmap's own axes explicitly rather than relying on
            # pyplot's "current figure", which may be a different figure.
            plt.setp(ax.get_yticklabels(), rotation=0)
            ax.set_ylabel("Trial")

            fig.savefig(f"{images_dir}/" + str(i)+ ".eps", format = "eps")

            # Render the heatmap itself to PNG so the returned / summarised
            # image shows the CAMs.
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            image = tf.image.decode_png(buf.getvalue(), channels=4)
            image = tf.expand_dims(image,0)
            tf.summary.image("plot", image)
        finally:
            # Release the figure; pyplot otherwise keeps it alive indefinitely.
            plt.close(fig)

    return image

In [None]:
# Per-image preprocessing handed to the CIFAR-10 datasets built inside train().
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # Normalizes to [-1, 1]
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        # Keep the first dimension and flatten all remaining ones into one.
        transforms.Lambda(lambda img: img.view(img.shape[0], -1)),
    ]
)

# Kick off training with the default learning rate, starting from epoch 0.
train(transform)