In [52]:
import torch
import torchvision
import torchvision.transforms as transforms

from datetime import datetime

import os
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import time

import matplotlib.pyplot as plt
import numpy as np

torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [53]:
transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.Grayscale(num_output_channels=3),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

In [54]:
training_set = torchvision.datasets.Caltech101('/home/crueang/Chaks/AIOT/data', transform=transform, download=True)
train_size = int(0.8 * len(training_set))
test_size = len(training_set) - train_size
training_set, validation_set = torch.utils.data.random_split(training_set, [train_size, test_size])

training_loader = torch.utils.data.DataLoader(training_set, batch_size=16, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=16, shuffle=False)

print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))

Files already downloaded and verified
Training set has 6941 instances
Validation set has 1736 instances


In [55]:
def plot_graph(history):
    fig, (ax1, ax2) = plt.subplots(1, 2)
    fig.set_figwidth(10)
    fig.suptitle("Train vs Validation")
    ax1.plot(history["train_acc"], label="Train")
    ax1.plot(history["validate_acc"], label="Validation")
    ax1.legend()
    ax1.set_title("Accuracy")

    ax2.plot(history["train_loss"], label="Train")
    ax2.plot(history["validate_loss"], label="Validation")
    ax2.legend()
    ax2.set_title("Loss")
    fig.show()

In [56]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True):
        super(ConvBlock,self).__init__()
        self.conv2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,kernel_size=kernel_size,stride=stride, padding=padding, bias=bias)
        self.batchnorm2d = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self,x):
        return self.relu(self.batchnorm2d(self.conv2d(x)))
    
class InceptionBlock(nn.Module):
    def __init__(self, in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1_pooling):
        super(InceptionBlock,self).__init__()

        self.branch1 = ConvBlock(in_channels,out_1x1,1,1,0)
        self.branch2 = nn.Sequential(
            ConvBlock(in_channels,red_3x3,1,1,0),
            ConvBlock(red_3x3,out_3x3,3,1,1)
        )
        self.branch3 = nn.Sequential(
            ConvBlock(in_channels,red_5x5,1,1,0),
            ConvBlock(red_5x5,out_5x5,5,1,2)
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3,stride=1,padding=1),
            ConvBlock(in_channels,out_1x1_pooling,1,1,0)
        )

    def forward(self,x):
        return torch.cat([self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)], dim=1)


In [57]:
class InceptionNet(nn.Module):
    def __init__(self):
        super(InceptionNet, self).__init__()
        
        self.conv1 = ConvBlock(3, 64, 7, 2, 3)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.conv2 = ConvBlock(64, 192, 3, 1, 1)
        
        self.inception1 = InceptionBlock(192, 64, 96, 128, 16, 32, 32)
        self.inception2 = InceptionBlock(256, 128, 128, 192, 32, 96, 64)
        
        self.inception3 = InceptionBlock(480, 192, 96, 208, 16, 48, 64)
        self.inception4 = InceptionBlock(512, 160, 112, 224, 24, 64, 64)
        
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(p=0.4)
        self.fc1 = nn.Linear(25088, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 101)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.maxpool(x)
        
        x = self.inception1(x)
        x = self.inception2(x)
        x = self.maxpool(x)
        
        x = self.inception3(x)
        x = self.inception4(x)
        x = self.maxpool(x)
        
        x = self.dropout(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        
        return x
    
model = InceptionNet().to(device)

In [139]:
class InceptionV3Block(nn.Module):
    def __init__(self, in_channels, out_1x1, out_1x1_pooling, red_1331_1x1, out_1x3_1x1, out_3x1_1x1, red_3x3, red1331_3x3, out1x3_3x3, out3x1_3x3):
        super(InceptionV3Block,self).__init__()
        
        self.branch1 = ConvBlock(in_channels, out_1x1, 1, 1, 0)
        self.branch2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            ConvBlock(in_channels, out_1x1_pooling, 1, 1, 0)
        )
        self.branch3 = ConvBlock(in_channels, red_1331_1x1, 1, 1, 0)
        self.branch3_1 = ConvBlock(red_1331_1x1, out_1x3_1x1, (1, 3), 1, (0, 1))
        self.branch3_2 = ConvBlock(red_1331_1x1, out_3x1_1x1, (3, 1), 1, (1, 0))
        self.branch4 = nn.Sequential(
            ConvBlock(in_channels, red_3x3, 1, 1, 0),
            ConvBlock(red_3x3, red1331_3x3, 3, 1, 1)
        )
        self.branch4_1 = ConvBlock(red1331_3x3, out1x3_3x3, (1, 3), 1, (0, 1))
        self.branch4_2 = ConvBlock(red1331_3x3, out3x1_3x3, (3, 1), 1, (1, 0))
    
    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch3_1 = self.branch3_1(branch3)
        branch3_2 = self.branch3_2(branch3)
        branch4 = self.branch4(x)
        branch4_1 = self.branch4_1(branch4)
        branch4_2 = self.branch4_2(branch4)
        
        return torch.cat([branch1, branch2, branch3_1, branch3_2, branch4_1, branch4_2], 1)

In [146]:
class InceptionV3Net(nn.Module):
    def __init__(self):
        super(InceptionV3Net, self).__init__()
        
        self.conv1 = ConvBlock(3, 64, 7, 2, 3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.conv2 = ConvBlock(64, 192, 3, 1, 1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.inception1 = InceptionV3Block(192, 64, 64, 128, 32, 32, 64, 128, 32, 32)
        self.inception2 = InceptionV3Block(256, 32, 48, 128, 24, 24, 64, 192, 32, 32)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.inception3 = InceptionV3Block(192, 192, 48, 96, 16, 16, 64, 128, 16, 16)
        self.inception4 = InceptionV3Block(304, 192, 64, 128, 24, 24, 64, 128, 24, 24)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(p=0.4)
        self.fc1 = nn.Linear(17248, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 101)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        
        x = self.conv2(x)
        x = self.maxpool2(x)
        
        x = self.inception1(x)
        x = self.inception2(x)
        x = self.maxpool3(x)
        
        x = self.inception3(x)
        x = self.inception4(x)
        x = self.maxpool4(x)
        
        x = self.flatten(x)
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        
        return x

model = InceptionV3Net().to(device)

In [147]:
summary(model, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,472
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         ConvBlock-4         [-1, 64, 112, 112]               0
         MaxPool2d-5           [-1, 64, 56, 56]               0
            Conv2d-6          [-1, 192, 56, 56]         110,784
       BatchNorm2d-7          [-1, 192, 56, 56]             384
              ReLU-8          [-1, 192, 56, 56]               0
         ConvBlock-9          [-1, 192, 56, 56]               0
        MaxPool2d-10          [-1, 192, 28, 28]               0
           Conv2d-11           [-1, 64, 28, 28]          12,352
      BatchNorm2d-12           [-1, 64, 28, 28]             128
             ReLU-13           [-1, 64, 28, 28]               0
        ConvBlock-14           [-1, 64,

In [148]:
EPOCHS = 10
checkpoint_path = '/home/crueang/Chaks/AIOT/5_1_homework/checkpoint/inceptionV3/checkpoint_inceptionV3_pretrained/'
training_logs = {"train_loss": [],  "train_acc": [], "validate_loss": [], "validate_acc": []}

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [60]:
from train import train

In [152]:
train(loss_fn, optimizer, model, training_logs, validation_loader, training_loader, EPOCHS, checkpoint_path=checkpoint_path, device=device)

Epochs 1  train_loss: 1.88616 train_acc: 0.34280 validate_loss: 1.51023 validate_acc: 0.44050 
--------------------------------------------------------------------------------
Epochs 2  train_loss: 1.34675 train_acc: 0.50320 validate_loss: 1.34851 validate_acc: 0.51500 
--------------------------------------------------------------------------------
Epochs 3  train_loss: 1.15161 train_acc: 0.57460 validate_loss: 1.25132 validate_acc: 0.55063 
--------------------------------------------------------------------------------
Epochs 4  train_loss: 1.01342 train_acc: 0.63220 validate_loss: 1.12508 validate_acc: 0.60325 
--------------------------------------------------------------------------------
Epochs 5  train_loss: 0.89668 train_acc: 0.66980 validate_loss: 1.15028 validate_acc: 0.59825 
--------------------------------------------------------------------------------
Epochs 6  train_loss: 0.79438 train_acc: 0.71620 validate_loss: 1.24228 validate_acc: 0.58925 
-------------------------

In [153]:
transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.Grayscale(num_output_channels=3),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
         ])

training_set = torchvision.datasets.STL10('/home/crueang/Chaks/AIOT/data', split='train', transform=transform, download=False)
validation_set = torchvision.datasets.STL10('/home/crueang/Chaks/AIOT/data', split='test', transform=transform, download=False)

training_loader = torch.utils.data.DataLoader(training_set, batch_size=16, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=16, shuffle=False)

print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))

Training set has 5000 instances
Validation set has 8000 instances


In [154]:
new_model_head = InceptionV3Net().to(device)
checkpoint = torch.load('/home/crueang/Chaks/AIOT/5_1_homework/checkpoint/inceptionV3/checkpoint_inceptionV3_pretrained/best_model.pth', weights_only=True)
new_model_head.load_state_dict(checkpoint, strict=False)
new_model_head.fc3 = nn.Linear(512, 10).to(device)
summary(new_model_head, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,472
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         ConvBlock-4         [-1, 64, 112, 112]               0
         MaxPool2d-5           [-1, 64, 56, 56]               0
            Conv2d-6          [-1, 192, 56, 56]         110,784
       BatchNorm2d-7          [-1, 192, 56, 56]             384
              ReLU-8          [-1, 192, 56, 56]               0
         ConvBlock-9          [-1, 192, 56, 56]               0
        MaxPool2d-10          [-1, 192, 28, 28]               0
           Conv2d-11           [-1, 64, 28, 28]          12,352
      BatchNorm2d-12           [-1, 64, 28, 28]             128
             ReLU-13           [-1, 64, 28, 28]               0
        ConvBlock-14           [-1, 64,

In [155]:
EPOCHS = 5
checkpoint_path = '/home/crueang/Chaks/AIOT/5_1_homework/checkpoint/inceptionV3/checkpoint_inceptionV3_resume/'
training_logs = {"train_loss": [],  "train_acc": [], "validate_loss": [], "validate_acc": []}

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(new_model_head.parameters(), lr=0.001, momentum=0.9)

In [156]:
def load_checkpoint(model, optimizer, training_logs, checkpoint_path=None, device='cpu'):
    epoch_number = 0
    best_vloss = float('inf')
    if checkpoint_path:
        if os.path.exists(checkpoint_path + 'model.pth'):
            model.load_state_dict(torch.load(checkpoint_path + 'model.pth', weights_only=True, map_location=device))

        if os.path.exists(checkpoint_path + 'opt.pth'):
            optimizer.load_state_dict(torch.load(checkpoint_path + 'opt.pth', weights_only=True, map_location=device))

        if os.path.exists(checkpoint_path + 'training_logs.pth'):
            training_logs = torch.load(checkpoint_path + 'training_logs.pth', weights_only=True)
            epoch_number = len(training_logs['train_loss'])
            best_vloss = min(training_logs['validate_loss'])
    
    for i in range(epoch_number):
        print(f"Epochs {i+1}".ljust(10), end='')
        for key in training_logs.keys():
            print(f"{key}: {training_logs[key][i]:.5f}", end=" ")
        print()
        print("-"*80)

    return training_logs, best_vloss, epoch_number

def train(loss_fn, optimizer, model, training_logs, validation_loader, training_loader, EPOCHS, checkpoint_path=None, device='cpu'):
    if checkpoint_path:
        if not os.path.exists(checkpoint_path):
            os.mkdir(checkpoint_path)
        training_logs, best_vloss, epoch_number = load_checkpoint(model, optimizer, training_logs, checkpoint_path, device)
    
    t_0_accelerated = time.time()
    for epoch in range(epoch_number, EPOCHS):
        train_loss, train_correct = 0, 0
        model.train(True)

        for i, data in enumerate(training_loader):
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()

            outputs = model(inputs)

            loss = loss_fn(outputs, labels)
            loss.backward()

            optimizer.step()

            train_loss += loss.item()
            train_correct += (outputs.argmax(1) == labels).float().sum().item()

        training_logs["train_loss"].append(train_loss / len(training_loader))
        training_logs["train_acc"].append(train_correct / len(training_loader.dataset))

        model.eval()
        valid_loss, valid_correct = 0, 0
        with torch.no_grad():
            for i, vdata in enumerate(validation_loader):
                vinputs, vlabels = vdata[0].to(device), vdata[1].to(device)
                voutputs = model(vinputs)

                valid_loss += loss_fn(voutputs, vlabels).item()
                valid_correct += (voutputs.argmax(1) == vlabels).float().sum().item()

            training_logs["validate_loss"].append(valid_loss / len(validation_loader))
            training_logs["validate_acc"].append(valid_correct / len(validation_loader.dataset))

        print(f"Epochs {epoch+1}".ljust(10), end='')
        for key in training_logs.keys():
            print(f"{key}: {training_logs[key][-1]:.5f}", end=" ")
        print()
        print("-"*80)

        if checkpoint_path:
            torch.save(model.state_dict(), checkpoint_path + "model.pth")
            torch.save(optimizer.state_dict(), checkpoint_path + "opt.pth")
            torch.save(training_logs, checkpoint_path + 'training_logs.pth')
            if best_vloss > valid_loss:
               torch.save(model.state_dict(), checkpoint_path + "best_model.pth")
               best_vloss = valid_loss

    t_end_accelerated = time.time()-t_0_accelerated
    print(f"Time consumption for accelerated CUDA training (device:{device}): {t_end_accelerated} sec")

In [157]:
train(loss_fn, optimizer, new_model_head, training_logs, validation_loader, training_loader, EPOCHS, checkpoint_path=checkpoint_path, device=device)

Epochs 1  train_loss: 1.30982 train_acc: 0.51940 validate_loss: 1.35321 validate_acc: 0.53138 
--------------------------------------------------------------------------------
Epochs 2  train_loss: 1.02195 train_acc: 0.63200 validate_loss: 1.10962 validate_acc: 0.59550 
--------------------------------------------------------------------------------
Epochs 3  train_loss: 0.89961 train_acc: 0.67220 validate_loss: 1.46929 validate_acc: 0.50750 
--------------------------------------------------------------------------------
Epochs 4  train_loss: 0.81768 train_acc: 0.69940 validate_loss: 1.26175 validate_acc: 0.59075 
--------------------------------------------------------------------------------
Epochs 5  train_loss: 0.70224 train_acc: 0.74600 validate_loss: 1.08161 validate_acc: 0.62662 
--------------------------------------------------------------------------------
Time consumption for accelerated CUDA training (device:cuda): 148.68760561943054 sec


In [158]:
finetune = InceptionV3Net().to(device)
checkpoint = torch.load('/home/crueang/Chaks/AIOT/5_1_homework/checkpoint/inceptionV3/checkpoint_inceptionV3_pretrained/best_model.pth', weights_only=True)
finetune.load_state_dict(checkpoint, strict=False)
for param in finetune.parameters():
    param.requires_grad = False
finetune.fc3 = nn.Linear(512, 10).to(device)
summary(finetune, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,472
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         ConvBlock-4         [-1, 64, 112, 112]               0
         MaxPool2d-5           [-1, 64, 56, 56]               0
            Conv2d-6          [-1, 192, 56, 56]         110,784
       BatchNorm2d-7          [-1, 192, 56, 56]             384
              ReLU-8          [-1, 192, 56, 56]               0
         ConvBlock-9          [-1, 192, 56, 56]               0
        MaxPool2d-10          [-1, 192, 28, 28]               0
           Conv2d-11           [-1, 64, 28, 28]          12,352
      BatchNorm2d-12           [-1, 64, 28, 28]             128
             ReLU-13           [-1, 64, 28, 28]               0
        ConvBlock-14           [-1, 64,

In [159]:
EPOCHS = 5
checkpoint_path = '/home/crueang/Chaks/AIOT/5_1_homework/checkpoint/inceptionV3/checkpoint_inceptionV3_finetune/'
training_logs = {"train_loss": [],  "train_acc": [], "validate_loss": [], "validate_acc": []}

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(finetune.parameters(), lr=0.001, momentum=0.9)

In [160]:
train(loss_fn, optimizer, finetune, training_logs, validation_loader, training_loader, EPOCHS, checkpoint_path=checkpoint_path, device=device)

Epochs 1  train_loss: 1.21402 train_acc: 0.55760 validate_loss: 1.17065 validate_acc: 0.57100 
--------------------------------------------------------------------------------
Epochs 2  train_loss: 1.05658 train_acc: 0.61400 validate_loss: 1.15870 validate_acc: 0.58013 
--------------------------------------------------------------------------------
Epochs 3  train_loss: 1.03112 train_acc: 0.62580 validate_loss: 1.16024 validate_acc: 0.57725 
--------------------------------------------------------------------------------
Epochs 4  train_loss: 1.00439 train_acc: 0.63940 validate_loss: 1.13167 validate_acc: 0.59637 
--------------------------------------------------------------------------------
Epochs 5  train_loss: 1.00078 train_acc: 0.63540 validate_loss: 1.16046 validate_acc: 0.58000 
--------------------------------------------------------------------------------
Time consumption for accelerated CUDA training (device:cuda): 118.77279782295227 sec


In [161]:
def get_labels_predictions(model, dataloader, device):
    model.eval()  # Set the model to evaluation mode
    all_labels = []
    all_predictions = []

    with torch.no_grad():  # Disable gradient computation
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(inputs)

            # Get the predicted class
            _, predictions = torch.max(outputs, 1)

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())

    return all_labels, all_predictions

In [1]:
# https://github.com/fyse-nassar/Malware-Family-Classification/blob/master/Malware%20Opcode%20Ngrams%20Generator.ipynb
# https://scikit-learn.org/0.18/auto_examples/model_selection/plot_confusion_matrix.html

import itertools
import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    given a sklearn confusion matrix (cm), make a nice plot

    Arguments
    ---------
    cm:           confusion matrix from sklearn.metrics.confusion_matrix

    target_names: given classification classes such as [0, 1, 2]
                  the class names, for example: ['high', 'medium', 'low']

    title:        the text to display at the top of the matrix

    cmap:         the gradient of the values displayed from matplotlib.pyplot.cm
                  see http://matplotlib.org/examples/color/colormaps_reference.html
                  plt.get_cmap('jet') or plt.cm.Blues

    Normalization can be applied by setting `normalize=True`.
    normalize:    If False, plot the raw numbers
                  If True, plot the proportions

    Usage
    -----
    plot_confusion_matrix(cm           = cm,                  # confusion matrix created by
                                                              # sklearn.metrics.confusion_matrix
                          normalize    = True,                # show proportions
                          target_names = y_labels_vals,       # list of names of the classes
                          title        = best_estimator_name) # title of graph

    Citiation
    ---------
    http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    """

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    #plt.xticks(tick_marks, classes, rotation=45)    #office-31
    plt.xticks(tick_marks, classes, rotation=20, fontsize=12)    #office-home
    plt.yticks(tick_marks, classes, fontsize=12)

    if normalize==True:
        cm = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis]+1)
        #print("Normalized confusion matrix")
    #else:
        #print('Confusion matrix, without normalization')

    # print(cm)

    formated = '.2f' if normalize==True else 'd'
    #---manual---
    thresh = cm.max() / 2.
    #thresh > 0.5
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], formated),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    #---sns---
    #df_cm = pd.DataFrame(cm, classes, classes)
    #sns.heatmap(df_cm, annot=True, fmt=formated, cmap=cmap)

    plt.gcf().set_size_inches(8, 6)
    plt.ylabel('Ground Truth')
    plt.xlabel('Prediction')
    plt.margins(2,2)
    plt.tight_layout()



"""
  Args: Collect feature
  Ref: https://github.com/zhjscut/Bridging_UDA_SSL/blob/e0be6742f1203bb983261e3e1e57d34e1e03299d/common/utils/analysis/__init__.py#L7
"""
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import tqdm
import os.path as osp


def collect_feature(data_loader: DataLoader, feature_extractor: nn.Module,
                                   device: torch.device, max_num_features=None) -> torch.Tensor:
    """
    Fetch data from `data_loader`, and then use `feature_extractor` to collect features
    Args:
        data_loader (torch.utils.data.DataLoader): Data loader.
        feature_extractor (torch.nn.Module): A feature extractor.
        device (torch.device)
        max_num_features (int): The max number of features to return
    Returns:
        Features in shape (min(len(data_loader), max_num_features), :math:`|\mathcal{F}|`).
    """
    feature_extractor.eval()
    all_features = []
    all_labels = []
    with torch.no_grad():
        for i, (images, target) in enumerate(tqdm.tqdm(data_loader)):
            images = images.to(device)
            features = feature_extractor(images)
            if isinstance(features, tuple):
                # Check if it's a tuple (common when using certain pre-trained models)
                # You may want to select the feature tensor you need from the tuple
                # For example, if the feature tensor is the first element of the tuple:
                feature_tensor = features[0]
                feature_tensor = feature_tensor.to(device)  # Move the tensor to CPU
            else:
                feature_tensor = features.to(device)  # Move the tensor to CPU

            all_features.append(feature_tensor)
            all_labels.append(target)

    return torch.cat(all_features, dim=0),\
        torch.cat(all_labels, dim=0) # Concatenate the list of feature tensors

    #        all_features.append(features)
    #        if max_num_features is not None and i >= max_num_features:
    #            break
    #return torch.cat(all_features, dim=0)


# ref https://github.com/zhjscut/Bridging_UDA_SSL/blob/e0be6742f1203bb983261e3e1e57d34e1e03299d/common/utils/analysis/__init__.py#L7
import torch
import matplotlib

matplotlib.use('Agg')
from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as col


def visualize(source_feature: torch.Tensor, target_feature: torch.Tensor,
              filename: str, source_color='r', target_color='b'):
    """
    Visualize features from different domains using t-SNE.
    Args:
        source_feature (tensor): features from source domain in shape :math:`(minibatch, F)`
        target_feature (tensor): features from target domain in shape :math:`(minibatch, F)`
        filename (str): the file name to save t-SNE
        source_color (str): the color of the source features. Default: 'r'
        target_color (str): the color of the target features. Default: 'b'
    """
    source_feature = source_feature.cpu().numpy()
    target_feature = target_feature.cpu().numpy()
    features = np.concatenate([source_feature, target_feature], axis=0)

    # map features to 2-d using TSNE
    X_tsne = TSNE(n_components=2, random_state=33).fit_transform(features)

    # domain labels, 1 represents source while 0 represents target
    domains = np.concatenate((np.ones(len(source_feature)), np.zeros(len(target_feature))))

    # visualize using matplotlib
    plt.figure(figsize=(10, 10))
    plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=domains, cmap=col.ListedColormap([source_color, target_color]), s=20)  #default: s=2
    plt.savefig(filename)

import torch
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import matplotlib.colors as col

"""
  Arg: t-SNE for class clustering visualization
"""

def visualize_class_n_domain(source_feature: torch.Tensor, target_feature: torch.Tensor, source_labels: torch.Tensor, target_labels: torch.Tensor, filename: str, source_color='r', target_color='b'):
    """
    Visualize features from different domains using t-SNE.
    Args:
        source_feature (tensor): features from source domain in shape :math:`(minibatch, F)`
        target_feature (tensor): features from target domain in shape :math:`(minibatch, F)`
        source_labels (tensor): class labels for source domain features
        target_labels (tensor): class labels for target domain features
        filename (str): the file name to save t-SNE
        source_color (str): the color of the source features. Default: 'r'
        target_color (str): the color of the target features. Default: 'b'
    """
    source_feature = source_feature.cpu().numpy()
    target_feature = target_feature.cpu().numpy()
    source_labels = source_labels.cpu().numpy()
    target_labels = target_labels.cpu().numpy()

    # Combine features and labels
    features = np.concatenate([source_feature, target_feature], axis=0)
    labels = np.concatenate([source_labels, target_labels], axis=0)
    domains = np.concatenate((np.ones(len(source_feature)), np.zeros(len(target_feature))))

    # Map features to 2-D using t-SNE
    X_tsne = TSNE(n_components=2, random_state=33).fit_transform(features)

    # Visualize using matplotlib
    plt.figure(figsize=(10, 10))

    # Get unique class labels
    unique_labels = np.unique(labels)

    # Create a color map for classes
    #cmap = plt.get_cmap('tab20', len(unique_labels))
    cmap_s = plt.get_cmap('nipy_spectral', len(unique_labels))
    cmap_r = plt.get_cmap('gist_rainbow', len(unique_labels))

    # Plot data points for each class and domain
    for label in unique_labels:
        for domain in [0, 1]:
            mask = (labels == label) & (domains == domain)
            plt.scatter(X_tsne[mask, 0], X_tsne[mask, 1], c=cmap_s(label), s=10,
                        #label=f"Class {label}, Domain {domain}",
                        )

    plt.legend()
    plt.savefig(filename)

# Example usage
# visualize(source_feature, target_feature, source_labels, target_labels, 'tsne_plot.png')



###########################################################################################
checkpoint = torch.load('/home/crueang/Chaks/AIOT/5_1_homework/checkpoint/inceptionV3/checkpoint_inceptionV3_finetune/best_model.pth', weights_only=True)
loaded_resume_model = InceptionV3Net().to(device)
loaded_resume_model.fc3 = nn.Linear(512, 10).to(device)

loaded_resume_model.load_state_dict(checkpoint, strict=False)


source_feature, s_labels = collect_feature(validation_loader, loaded_resume_model, device)
# target_feature, t_labels = collect_feature(target_dl, feature_extractor, device)

# --- plot t-SNE
if not os.path.exists('/home/crueang/Chaks/AIOT/5_1_homework/output'): os.mkdir('/home/crueang/Chaks/AIOT/5_1_homework/output')
tSNE_filename = osp.join('/home/crueang/Chaks/AIOT/5_1_homework/output', 'W5-1__resumev3_tSNE.png')
visualize_class_n_domain(source_feature, source_feature, s_labels, s_labels, tSNE_filename)    # single-domain multi-class rep
# visualize_class_n_domain(source_feature, target_feature, s_labels, t_labels, tSNE_filename)    # two-domain multi-class rep
print("Saving t-SNE to", tSNE_filename)


#--- Confusion matrix, F1-score, precision, recall, NMI/RI scores
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, normalized_mutual_info_score, adjusted_rand_score#, f1_score
# Class labels
pos_labels = np.arange(10)
lb_classes = ('0', '1', '2', '3', '4', '5', '6' ,'7', '8', '9')

lb, prd = get_labels_predictions(loaded_resume_model, validation_loader, device)
#---confusion matrix
cm_target = confusion_matrix(y_true=lb,
                            y_pred=prd,
                            labels=pos_labels,
                            normalize='true',
                            )
plt.figure()
plt.rcParams.update({'font.size': 10, 'figure.figsize': (2,2)})
plot_confusion_matrix(cm_target,
                        classes=lb_classes,
                        normalize=True,
                        title='Conf. Mat. w.r.t. STL-10 ds',
                        cmap=plt.cm.binary #Blues_r
                        )    #Blues_r = off-white diagonal
#---F1-score/Precision/Recall scores
print("Precision/Recall/F-beta score:", precision_recall_fscore_support(lb, prd, average='weighted', zero_division=0,
                                          beta=1.0)) #labels=label_classes))
#---Normalized Mutual Information (NMI) score
nmi_score = normalized_mutual_info_score(labels_true=lb,
                                        labels_pred=prd,
                                        average_method='arithmetic',
                                        )
#---Rand Index (RI) score
ri_score = adjusted_rand_score(labels_true=lb,
                                labels_pred=prd,
                                )
print(f"NMI score: {nmi_score}, RI score: {ri_score}")

UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
 Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.