In [128]:
from typing import List
from torchvision import transforms, models, datasets
import torch 
from torch.utils.data import DataLoader, Subset, random_split, ConcatDataset
import numpy as np 
import random 
from os.path import exists
import os 
from typing import List, Dict, Tuple
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.utils.multiclass import unique_labels
from sklearn.metrics import accuracy_score, confusion_matrix
import torch.nn as nn
from torch.functional import F
from tqdm import trange
import copy

In [129]:
def get_oxford_splits(
    batch_size: int,
    data_loader_seed: int = 111, 
    pin_memory: bool = False,
    num_workers: int = 1,
    ): 
    K = 5
    num_support = 80
    num_query = 20 

    def seed_worker(worker_id):
        # worker_seed = torch.initial_seed() % 2 ** 32
        np.random.seed(data_loader_seed)
        random.seed(data_loader_seed)
    g = torch.Generator()
    g.manual_seed(data_loader_seed)

    support_classes = list(np.arange(num_support))
    query_classes = list(np.arange(num_query) + num_support)

    img_dim = 64

    train_transforms = transforms.Compose([
        transforms.Resize((img_dim, img_dim)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
    test_transforms = transforms.Compose([
        transforms.Resize((img_dim, img_dim)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
    validation_transforms = transforms.Compose([
        transforms.Resize((img_dim, img_dim)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
    
    data_path = f'/home/hesam/projects/Phase4/dataset'

    train_ds_full = datasets.Flowers102(root=data_path, split="train", download=True, transform=train_transforms)
    val_ds_full = datasets.Flowers102(root=data_path, split="val", download=True, transform=validation_transforms)
    test_ds_full = datasets.Flowers102(root=data_path, split="test", download=True, transform=test_transforms)

    train_indxs_support = torch.where(torch.isin(torch.tensor(train_ds_full._labels), torch.asarray(support_classes)))[0]
    val_indxs_support = torch.where(torch.isin(torch.tensor(val_ds_full._labels), torch.asarray(support_classes)))[0]
    test_indxs_support = torch.where(torch.isin(torch.tensor(test_ds_full._labels), torch.asarray(support_classes)))[0]
    
    train_ds_subset_support = torch.utils.data.Subset(train_ds_full, train_indxs_support)
    val_ds_subset_support = torch.utils.data.Subset(val_ds_full, val_indxs_support)
    test_ds_subset_support = torch.utils.data.Subset(test_ds_full, test_indxs_support)

    merged_dataset = ConcatDataset([train_ds_subset_support, val_ds_subset_support, test_ds_subset_support])
    ### A, B
    train_ds_support, test_ds_support = torch.utils.data.random_split(merged_dataset, [0.75, 0.25], generator=torch.Generator().manual_seed(42))
    ### 

    train_indxs_query = torch.where(torch.isin(torch.tensor(train_ds_full._labels), torch.asarray(query_classes)))[0]
    N = 10 
    starting_indices = np.arange(0, len(train_indxs_query), N)
    train_indxs_query = np.hstack([train_indxs_query[i:i+K] for i in starting_indices if i + K <= len(train_indxs_query)])
    ### C
    train_ds_query = torch.utils.data.Subset(train_ds_full, train_indxs_query) 
    ###

    val_indxs_query = torch.where(torch.isin(torch.tensor(val_ds_full._labels), torch.asarray(query_classes)))[0]
    test_indxs_query = torch.where(torch.isin(torch.tensor(test_ds_full._labels), torch.asarray(query_classes)))[0]
    val_ds_subset_query = torch.utils.data.Subset(val_ds_full, val_indxs_query)
    test_ds_subset_query = torch.utils.data.Subset(test_ds_full, test_indxs_query)
    
    test_ds_query_full = ConcatDataset([val_ds_subset_query, test_ds_subset_query])
    ### D 
    _, test_ds_query = torch.utils.data.random_split(test_ds_query_full, [0.7, 0.3], generator=torch.Generator().manual_seed(42))
    ###
    
    ### E 
    test_all = ConcatDataset([test_ds_query, test_ds_support])


    A_train_dl = DataLoader(
        train_ds_support,
        batch_size = batch_size, 
        shuffle=True, 
        worker_init_fn=seed_worker,
        generator=g,
        drop_last=False,
        pin_memory=pin_memory,
        num_workers=num_workers
    )
    A_test_dl = DataLoader(
        test_ds_support,
        batch_size = batch_size, 
        shuffle=True, 
        worker_init_fn=seed_worker,
        generator=g,
        drop_last=False,
        pin_memory=pin_memory,
        num_workers=num_workers
    )

    B_train_dl = DataLoader(
        train_ds_query,
        batch_size = batch_size, 
        shuffle=True, 
        worker_init_fn=seed_worker,
        generator=g,
        drop_last=False,
        pin_memory=pin_memory,
        num_workers=num_workers
    )
    B_test_dl = DataLoader(
        test_ds_query,
        batch_size = batch_size, 
        shuffle=True, 
        worker_init_fn=seed_worker,
        generator=g,
        drop_last=False,
        pin_memory=pin_memory,
        num_workers=num_workers
    )
    test_all = DataLoader(
        test_all,
        batch_size = batch_size, 
        shuffle=True, 
        worker_init_fn=seed_worker,
        generator=g,
        drop_last=False,
        pin_memory=pin_memory,
        num_workers=num_workers
    )
    
    return A_train_dl, A_test_dl, B_train_dl, B_test_dl, test_all

In [130]:
def make_dir(dir_name: str):
    """
    creates directory "dir_name" if it doesn't exists
    """
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)

In [131]:
def plot_conf(labels, preds, title, dir_, name):
    """
    labels: an [N, ] array containing true labels for N samples
    preds: an [N, ] array containing predications for N samples
    
    saves confusion matrix plot of the given prediction and true labels in 'dir_/name.jpg' 
    """

    conf = confusion_matrix(labels, preds)
    # print("count", torch.unique(preds, return_counts=True))
    plt.clf()
    cm = conf.astype('float') / conf.sum(axis=1)[:, np.newaxis]
    cmap = sns.light_palette("navy", as_cmap=True)
    plt.figure(figsize=(20, 20))
    sns.heatmap(cm, annot=False, cmap=cmap, fmt=".2f", cbar=False)
    # plt.title(f'{title}')
    # plt.xlabel('Predicted Label')
    # plt.ylabel('True Label')
    make_dir(dir_)
    plt.savefig(f'{dir_}/{name}')

In [132]:
A_train_dl, A_test_dl, B_train_dl, B_test_dl, test_all = get_oxford_splits(128)

In [133]:
use_gpu = True
device = torch.device("cuda:0" if torch.cuda.is_available() and use_gpu else "cpu") # New


#### define network with layers

In [134]:
class My_Model(nn.Module):
    
    def __init__(self) -> None:
        super().__init__()
        
        self.flatten = nn.Flatten()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), padding=1, stride=1)
        self.bn1 = nn.BatchNorm2d(64)
        
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=1, stride=1)
        self.bn2 = nn.BatchNorm2d(64)
        
        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=1, stride=1)
        self.bn2_1 = nn.BatchNorm2d(64)
        
        self.conv2_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=1, stride=1)
        self.bn2_2 = nn.BatchNorm2d(64)
        
        self.conv2_3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=1, stride=1)
        self.bn2_3 = nn.BatchNorm2d(64)
        
        self.pool1 = nn.AvgPool2d(kernel_size=(2, 2))
        
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=96, kernel_size=(3, 3), padding=1, stride=1)
        self.bn3 = nn.BatchNorm2d(96)
        
        self.conv3_1 = nn.Conv2d(in_channels=96, out_channels=96, kernel_size=(3, 3), padding=1, stride=1)
        self.bn3_1 = nn.BatchNorm2d(96)
        
        self.conv3_2 = nn.Conv2d(in_channels=96, out_channels=96, kernel_size=(3, 3), padding=1, stride=1)
        self.bn3_2 = nn.BatchNorm2d(96)
        
        self.conv3_3 = nn.Conv2d(in_channels=96, out_channels=96, kernel_size=(3, 3), padding=1, stride=1)
        self.bn3_3 = nn.BatchNorm2d(96)
        
        self.pool2 = nn.AvgPool2d(kernel_size=(2, 2))
        
        self.conv4 = nn.Conv2d(in_channels=96, out_channels=128, kernel_size=(3, 3), padding=1, stride=1)
        self.bn4 = nn.BatchNorm2d(128)
        
        self.conv4_1 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=1, stride=1)
        self.bn4_1 = nn.BatchNorm2d(128)
        
        self.conv4_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=1, stride=1)
        self.bn4_2 = nn.BatchNorm2d(128)
        
        self.conv4_3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), padding=1, stride=1)
        self.bn4_3 = nn.BatchNorm2d(128)        
        
        self.pool3 = nn.AvgPool2d(kernel_size=(2, 2))

        self.conv5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), padding=1, stride=1)
        self.bn5 = nn.BatchNorm2d(256)
        
        self.conv5_1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), padding=1, stride=1)
        self.bn5_1 = nn.BatchNorm2d(256)
        
        self.conv5_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), padding=1, stride=1)
        self.bn5_2 = nn.BatchNorm2d(256)
        
        self.conv5_3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), padding=1, stride=1)
        self.bn5_3 = nn.BatchNorm2d(256)


        self.pool4 = nn.AvgPool2d(kernel_size=(2, 2))
        
        self.fc1 = nn.Linear(in_features=4096, out_features=80)
        
    def forward(self, inputs, debug=False):
        x1 = self.conv1(inputs)
        x1 = self.bn1(x1)
        x1 = F.relu(x1)
        
        x2 = self.conv2(x1)
        x2 = self.bn2(x2)
        x2 = F.relu(x2)
        
        x2 = self.conv2_1(x2)
        x2 = self.bn2_1(x2)
        x2 = F.relu(x2)
        
        x2 = self.conv2_2(x2)
        x2 = self.bn2_2(x2)
        x2 = F.relu(x2)
        
        x2 = self.conv2_3(x2)
        x2 = self.bn2_3(x2)
        x2 = F.relu(x2)  

        x2 = self.pool1(x2)

        x3 = self.conv3(x2)
        x3 = self.bn3(x3)
        x3 = F.relu(x3)
        
        x3 = self.conv3_1(x3)
        x3 = self.bn3_1(x3)
        x3 = F.relu(x3)
        
        x3 = self.conv3_2(x3)
        x3 = self.bn3_2(x3)
        x3 = F.relu(x3)
        
        x3 = self.conv3_3(x3)
        x3 = self.bn3_3(x3)
        x3 = F.relu(x3)
    

        x3 = self.pool2(x3)

        x4 = self.conv4(x3)
        x4 = self.bn4(x4)
        x4 = F.relu(x4)
        
        x4 = self.conv4_1(x4)
        x4 = self.bn4_1(x4)
        x4 = F.relu(x4)
        
        x4 = self.conv4_2(x4)
        x4 = self.bn4_2(x4)
        x4 = F.relu(x4)
        
        x4 = self.conv4_3(x4)
        x4 = self.bn4_3(x4)
        x4 = F.relu(x4)

        x4 = self.pool3(x4)

        x5 = self.conv5(x4)
        x5 = self.bn5(x5)
        x5 = F.relu(x5)
        
        x5 = self.conv5_1(x5)
        x5 = self.bn5_1(x5)
        x5 = F.relu(x5)
        
        x5 = self.conv5_2(x5)
        x5 = self.bn5_2(x5)
        x5 = F.relu(x5)
        
        x5 = self.conv5_3(x5)
        x5 = self.bn5_3(x5)
        x5 = F.relu(x5)
  

        x6 = self.pool4(x5)
        x0 = self.flatten(x6)
        outputs = self.fc1(x0)
        
        if debug:
                    print('inputs shape: ', inputs.shape) # inputs in shape [N, C, H, W]
                    print('after flattening: ', x0.shape)
                    print('Activations after 1st fully connected layer: ', x1.shape)
                    print('Activations after 2nd fully connected layer: ', x2.shape)
                    print('Output shape: ', outputs.shape)

        return outputs
       

In [137]:
def train_one_epoch(modelIN: nn.Module, optim: torch.optim.Optimizer,
         dataloader: DataLoader, loss_fn, freeze_fc=False):

    num_samples = len(dataloader.dataset)
    num_batches = len(dataloader)
    running_corrects = 0
    running_loss = 0.0
    modelIN.train() #
    for batch_indx, (inputs, targets) in enumerate(dataloader): # Get a batch of Data
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        outputs = modelIN(inputs) # Forward Pass, [N, 10]
        loss = loss_fn(outputs, targets) # Compute Loss

        loss.backward() # Compute Gradients
        if freeze_fc:
            for param in modelIN.fc1.parameters():
                param.grad[:80] = 0 
            
        optim.step() # Update parameters
        optim.zero_grad() # zero the parameter's 

        _, preds = torch.max(outputs, dim=1) # Explain, [N]
        # print("pred", preds, targets)
        running_corrects += torch.sum(preds == targets)
        running_loss += loss.item()

    epoch_acc = (running_corrects / num_samples) * 100
    epoch_loss = (running_loss / num_batches)

    return epoch_acc, epoch_loss

In [138]:
count = 0

In [139]:
def test_model(modelIN: nn.Module,
         dataloader: DataLoader, loss_fn, dir, cpu=False, cm=False):

    # utils
    num_samples = len(dataloader.dataset)
    num_batches = len(dataloader)
    running_corrects = 0
    running_loss = 0.0
    predictions = []
    labels = []
    modelIN.eval() # you must call `model.eval()` to set dropout and batch normalization layers to evaluation mode before running inference.
    with torch.no_grad(): # explain
        # more on torch.no_grad(): https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html#disabling-gradient-tracking

        for batch_indx, (inputs, targets) in enumerate(dataloader): # Get a batch of Data
            if cpu:
                inputs = inputs.to("cpu")
                targets = targets.to("cpu")
            else:
                inputs = inputs.to(device)
                targets = targets.to(device)
            
            outputs = modelIN(inputs) # Forward Pass
            loss = loss_fn(outputs, targets) # Compute Loss

            # loss.backward() # Compute Gradients
            # optim.step() # Update parameters
            # optim.zero_grad() # zero the parameter's gradients

            _, preds = torch.max(outputs, 1) #
            running_corrects += torch.sum(preds == targets)
            running_loss += loss.item()
            predictions.append(preds.cpu())
            labels.append(targets.cpu())
    # cm flag to plot the cofussion matrix
    if cm:
        global count
        count += 1
        predictions = np.concatenate(predictions, axis=0).flatten()
        labels = np.concatenate(labels, axis=0).flatten()
        plot_conf(labels, predictions, title="CM", dir_=dir, name=f"confusion_matrix-{count}")
    test_acc = (running_corrects / num_samples) * 100
    test_loss = (running_loss / num_batches)

    return test_acc, test_loss

In [140]:
def custom_plot_training_stats(
        acc_hist,
        loss_hist,
        phase_list,
        title: str,
        dir: str,
        index : int,
        name: str = 'acc_loss'
        ):
    fig, (ax1, ax2) = plt.subplots(nrows = 1, ncols = 2, figsize=[14, 6], dpi=100)

    for phase in phase_list:

        lowest_loss_x = np.argmin(np.array(loss_hist[phase]))
        lowest_loss_y = loss_hist[phase][lowest_loss_x]

        ax1.annotate("{:.4f}".format(lowest_loss_y), [lowest_loss_x, lowest_loss_y])
        ax1.plot(loss_hist[phase], '-x', label=f'{phase} loss', markevery = [lowest_loss_x])

        ax1.set_xlabel(xlabel='epochs')
        ax1.set_ylabel(ylabel='loss')

        ax1.grid(color = 'green', linestyle = '--', linewidth = 0.5, alpha=0.75)
        ax1.legend()
        ax1.label_outer()

    # acc:
    for phase in phase_list:
        highest_acc_x = np.argmax(np.array(acc_hist[phase]))
        highest_acc_y = acc_hist[phase][highest_acc_x]

        ax2.annotate("{:.4f}".format(highest_acc_y), [highest_acc_x, highest_acc_y])
        ax2.plot(acc_hist[phase], '-x', label=f'{phase} acc', markevery = [highest_acc_x])

        ax2.set_xlabel(xlabel='epochs')
        ax2.set_ylabel(ylabel='acc')

        ax2.grid(color = 'green', linestyle = '--', linewidth = 0.5, alpha=0.75)
        ax2.legend()
        #ax2.label_outer()

    fig.suptitle(f'{title}')

    plt.savefig(f'chart/{name}-{index}.jpg')
    plt.clf()

In [141]:
def run(train_data, test_data, input_model, index, epochs, lr, test=False, freeze_fc=False, cpu=False, cm=False, dir=None):
    batch_size = 32
    num_epochs = epochs
    learning_rate = lr
    
    model = input_model
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    cross_entropy = nn.CrossEntropyLoss()

    acc_history = {'train': [], 'test': []}
    loss_history = {'train': [], 'test': []}

    for epoch in trange(num_epochs):
        if not test:
            train_acc, train_loss = train_one_epoch(modelIN=model, optim=optimizer, dataloader=train_data, loss_fn=cross_entropy, freeze_fc=freeze_fc)
            print("train: ", train_acc.cpu(), train_loss)
            acc_history['train'].append(train_acc.cpu())
            loss_history['train'].append(train_loss)
        test_acc, test_loss = test_model(modelIN=model, dataloader=test_data, loss_fn=cross_entropy, dir=dir, cpu=cpu, cm=cm)
        print("test: ", test_acc.cpu(), test_loss)
        acc_history['test'].append(test_acc.cpu())
        loss_history['test'].append(test_loss)
        
    custom_plot_training_stats(acc_history, loss_history, ['train', 'test'], title='demp', dir='demo_plots', index=index)


#### train the original model

In [142]:
my_model = My_Model()
run(A_train_dl, A_test_dl, my_model, 1, 90, 0.0001)

  0%|          | 0/90 [00:00<?, ?it/s]

train:  tensor(13.6669) 3.7138580438253044


  1%|          | 1/90 [00:14<21:04, 14.20s/it]

test:  tensor(7.2172) 4.1654169376079855
train:  tensor(31.1241) 2.777406273661433


  2%|▏         | 2/90 [00:27<19:37, 13.38s/it]

test:  tensor(34.6554) 2.547719258528489
train:  tensor(43.0583) 2.230035443563719


  3%|▎         | 3/90 [00:39<19:02, 13.13s/it]

test:  tensor(41.4174) 2.277535080909729
train:  tensor(52.0684) 1.828755572035506


  4%|▍         | 4/90 [00:52<18:36, 12.98s/it]

test:  tensor(48.8947) 1.9947178638898408
train:  tensor(59.6708) 1.5474707274823576


  6%|▌         | 5/90 [01:05<18:17, 12.91s/it]

test:  tensor(48.5046) 1.9408813623281627
train:  tensor(64.5657) 1.343338206007674


  7%|▋         | 6/90 [01:18<17:59, 12.85s/it]

test:  tensor(52.7958) 1.626639696267935
train:  tensor(70.6303) 1.131878846400493


  8%|▊         | 7/90 [01:30<17:38, 12.76s/it]

test:  tensor(54.9415) 1.640696488893949
train:  tensor(75.3303) 0.9667833544112541


  9%|▉         | 8/90 [01:43<17:20, 12.68s/it]

test:  tensor(61.5085) 1.3236527030284588
train:  tensor(79.4455) 0.8600353453610394


 10%|█         | 9/90 [01:55<17:07, 12.68s/it]

test:  tensor(59.4278) 1.5195385676163893
train:  tensor(81.4598) 0.7651099031035965


 11%|█         | 10/90 [02:08<16:55, 12.70s/it]

test:  tensor(61.8335) 1.367955758021428
train:  tensor(83.5824) 0.686224415495589


 12%|█▏        | 11/90 [02:21<16:45, 12.73s/it]

test:  tensor(62.4187) 1.2991699828551366
train:  tensor(88.9755) 0.5309596738299808


 13%|█▎        | 12/90 [02:34<16:32, 12.72s/it]

test:  tensor(64.6944) 1.5089779266944299
train:  tensor(87.9359) 0.5213169866317028


 14%|█▍        | 13/90 [02:46<16:16, 12.69s/it]

test:  tensor(65.6047) 1.1538860614483173
train:  tensor(89.0622) 0.4768502857234027


 16%|█▌        | 14/90 [02:59<16:04, 12.69s/it]

test:  tensor(59.3628) 1.6532095762399526
train:  tensor(90.9898) 0.41119616539091675


 17%|█▋        | 15/90 [03:12<15:55, 12.74s/it]

test:  tensor(67.2952) 1.102207868718184
train:  tensor(94.6285) 0.2986086409639668


 18%|█▊        | 16/90 [03:25<15:43, 12.75s/it]

test:  tensor(67.6853) 1.0851556223172407
train:  tensor(94.3903) 0.3270753259594376


 19%|█▉        | 17/90 [03:38<15:36, 12.83s/it]

test:  tensor(62.8739) 1.4999559475825384
train:  tensor(92.9825) 0.3360032757391801


 20%|██        | 18/90 [03:51<15:25, 12.86s/it]

test:  tensor(68.0104) 1.0829934500730956
train:  tensor(96.5129) 0.20916004277564385


 21%|██        | 19/90 [04:03<15:09, 12.81s/it]

test:  tensor(69.9610) 1.3651714829298167
train:  tensor(97.6175) 0.16965600646830895


 22%|██▏       | 20/90 [04:16<14:58, 12.84s/it]

test:  tensor(71.0663) 1.1028153942181513
train:  tensor(97.1410) 0.19186090617566495


 23%|██▎       | 21/90 [04:29<14:49, 12.89s/it]

test:  tensor(69.7659) 1.0509891143211951
train:  tensor(97.3143) 0.16542889359029564


 24%|██▍       | 22/90 [04:42<14:37, 12.90s/it]

test:  tensor(71.7165) 0.993669834274512
train:  tensor(98.5272) 0.11565570754779352


 26%|██▌       | 23/90 [04:55<14:18, 12.82s/it]

test:  tensor(72.8869) 0.92861674955258
train:  tensor(98.8304) 0.12433601029821344


 27%|██▋       | 24/90 [05:07<14:03, 12.78s/it]

test:  tensor(66.9701) 1.3296951422324548
train:  tensor(96.7295) 0.18559916559103373


 28%|██▊       | 25/90 [05:20<13:49, 12.76s/it]

test:  tensor(72.9519) 1.1918590710713313
train:  tensor(99.0253) 0.10201935157985301


 29%|██▉       | 26/90 [05:33<13:38, 12.79s/it]

test:  tensor(73.9272) 0.9136667626981552
train:  tensor(97.9857) 0.14109106220909068


 30%|███       | 27/90 [05:46<13:24, 12.76s/it]

test:  tensor(72.1717) 1.0537490477928748
train:  tensor(99.4152) 0.07887122296803706


 31%|███       | 28/90 [05:58<13:10, 12.75s/it]

test:  tensor(70.4811) 1.0400560177289522
train:  tensor(95.3216) 0.22604145894984942


 32%|███▏      | 29/90 [06:11<12:57, 12.74s/it]

test:  tensor(72.4317) 1.0332441559204688
train:  tensor(99.1986) 0.07858444267028086


 33%|███▎      | 30/90 [06:24<12:56, 12.94s/it]

test:  tensor(72.4317) 0.9959044800354884
train:  tensor(99.3719) 0.07036752797461845


 34%|███▍      | 31/90 [06:38<12:45, 12.98s/it]

test:  tensor(73.1469) 0.9720059922681406
train:  tensor(98.9387) 0.10496648521842183


 36%|███▌      | 32/90 [06:51<12:34, 13.01s/it]

test:  tensor(72.4967) 0.9388530121113245
train:  tensor(98.4189) 0.10606877999128522


 37%|███▋      | 33/90 [07:03<12:19, 12.97s/it]

test:  tensor(73.1469) 1.0934610733619103
train:  tensor(97.5309) 0.1695399699178902


 38%|███▊      | 34/90 [07:16<12:03, 12.93s/it]

test:  tensor(66.9051) 1.4270570736664991
train:  tensor(93.1341) 0.28509452596709534


 39%|███▉      | 35/90 [07:29<11:49, 12.91s/it]

test:  tensor(69.3758) 1.158681952036344
train:  tensor(98.6138) 0.09742195370632249


 40%|████      | 36/90 [07:42<11:35, 12.89s/it]

test:  tensor(73.5371) 0.9470580605646739
train:  tensor(99.1770) 0.06430270507730343


 41%|████      | 37/90 [07:55<11:21, 12.86s/it]

test:  tensor(73.9922) 0.9026953400327609
train:  tensor(99.9567) 0.028279752797774366


 42%|████▏     | 38/90 [08:08<11:07, 12.84s/it]

test:  tensor(76.3979) 0.8350235941604927
train:  tensor(98.8304) 0.074356256405244


 43%|████▎     | 39/90 [08:20<10:54, 12.83s/it]

test:  tensor(73.8622) 1.006685353242434
train:  tensor(98.8737) 0.08025768545229693


 44%|████▍     | 40/90 [08:34<10:47, 12.95s/it]

test:  tensor(73.7971) 0.9359503407031298
train:  tensor(98.8954) 0.07163764294740316


 46%|████▌     | 41/90 [08:47<10:46, 13.18s/it]

test:  tensor(73.4720) 0.9860666210834796
train:  tensor(99.5018) 0.048723930392313645


 47%|████▋     | 42/90 [09:01<10:35, 13.24s/it]

test:  tensor(77.0481) 0.8560367547548734
train:  tensor(99.8051) 0.02537513051081348


 48%|████▊     | 43/90 [09:13<10:12, 13.02s/it]

test:  tensor(76.2679) 0.8559430287434504
train:  tensor(99.9783) 0.01339249568362091


 49%|████▉     | 44/90 [09:26<09:51, 12.85s/it]

test:  tensor(76.5930) 0.8442956392581646
train:  tensor(99.8267) 0.0225107193848974


 50%|█████     | 45/90 [09:38<09:33, 12.76s/it]

test:  tensor(76.0728) 0.8360181015271407
train:  tensor(99.3935) 0.03940932039876242


 51%|█████     | 46/90 [09:51<09:18, 12.69s/it]

test:  tensor(75.1625) 1.4906091048167303
train:  tensor(99.9350) 0.027703185125279264


 52%|█████▏    | 47/90 [10:03<09:03, 12.63s/it]

test:  tensor(75.9428) 1.0025635590920081
train:  tensor(98.8521) 0.07323419098817818


 53%|█████▎    | 48/90 [10:16<08:51, 12.66s/it]

test:  tensor(73.6671) 0.9365312557380933
train:  tensor(99.7184) 0.05878472204848721


 54%|█████▍    | 49/90 [10:29<08:49, 12.91s/it]

test:  tensor(60.5332) 1.637481187398617
train:  tensor(88.2824) 0.4444113621437872


 56%|█████▌    | 50/90 [10:43<08:42, 13.06s/it]

test:  tensor(63.9792) 1.4450798401465783
train:  tensor(96.1014) 0.16189338166165995


 57%|█████▋    | 51/90 [10:57<08:40, 13.35s/it]

test:  tensor(73.2120) 0.9493617969699419
train:  tensor(98.2240) 0.10132665660332989


 58%|█████▊    | 52/90 [11:09<08:17, 13.10s/it]

test:  tensor(74.1222) 1.0286013575700612
train:  tensor(99.4369) 0.05271259936931971


 59%|█████▉    | 53/90 [11:22<07:56, 12.89s/it]

test:  tensor(74.9025) 1.072484314441681
train:  tensor(98.1806) 0.10377411826236828


 60%|██████    | 54/90 [11:35<07:42, 12.86s/it]

test:  tensor(73.2770) 1.0450071738316462
train:  tensor(97.0760) 0.12853241278915792


 61%|██████    | 55/90 [11:49<07:42, 13.21s/it]

test:  tensor(73.6671) 1.1885017569248493
train:  tensor(99.1336) 0.05474364306979083


 62%|██████▏   | 56/90 [12:05<08:04, 14.24s/it]

test:  tensor(75.8778) 1.1625387439360986
train:  tensor(99.9350) 0.015054259499585306


 63%|██████▎   | 57/90 [12:19<07:49, 14.22s/it]

test:  tensor(77.9584) 0.8245829701996766
train:  tensor(99.9567) 0.011592040404778075


 64%|██████▍   | 58/90 [12:36<07:53, 14.80s/it]

test:  tensor(77.6983) 1.185708252283243
train:  tensor(99.9134) 0.011550282077813471


 66%|██████▌   | 59/90 [12:49<07:28, 14.46s/it]

test:  tensor(77.5683) 0.8446027951744887
train:  tensor(99.9350) 0.024133913810490758


 67%|██████▋   | 60/90 [13:03<07:03, 14.11s/it]

test:  tensor(75.0975) 0.9268425519649799
train:  tensor(98.1590) 0.10876355151570327


 68%|██████▊   | 61/90 [13:18<07:00, 14.49s/it]

test:  tensor(72.8869) 1.355717842395489
train:  tensor(98.6355) 0.07215739162387075


 69%|██████▉   | 62/90 [13:32<06:38, 14.25s/it]

test:  tensor(75.6177) 0.9125682135614065
train:  tensor(99.8267) 0.0206706174271735


 70%|███████   | 63/90 [13:45<06:16, 13.96s/it]

test:  tensor(75.4226) 0.9011656985833094
train:  tensor(99.3719) 0.035902513405056416


 71%|███████   | 64/90 [13:59<06:01, 13.90s/it]

test:  tensor(75.6177) 0.8883448942349508
train:  tensor(99.9567) 0.031944618094712496


 72%|███████▏  | 65/90 [14:12<05:39, 13.60s/it]

test:  tensor(76.3329) 0.9069565666409639
train:  tensor(98.2023) 0.10400046912852574


 73%|███████▎  | 66/90 [14:26<05:30, 13.76s/it]

test:  tensor(68.2705) 1.3198027060582087
train:  tensor(98.1590) 0.10247473992608688


 74%|███████▍  | 67/90 [14:41<05:24, 14.09s/it]

test:  tensor(73.2770) 1.0142773940013006
train:  tensor(99.6535) 0.03218795690131751


 76%|███████▌  | 68/90 [14:55<05:09, 14.09s/it]

test:  tensor(72.6918) 1.097326086117671
train:  tensor(96.9894) 0.14222990832216031


 77%|███████▋  | 69/90 [15:08<04:54, 14.00s/it]

test:  tensor(74.6424) 0.9870876371860504
train:  tensor(99.3502) 0.039603477293575134


 78%|███████▊  | 70/90 [15:23<04:43, 14.16s/it]

test:  tensor(77.2432) 0.9262764774836026
train:  tensor(99.6968) 0.024892290022123505


 79%|███████▉  | 71/90 [15:38<04:32, 14.34s/it]

test:  tensor(77.9584) 0.9356433428250827
train:  tensor(100.) 0.009886012223826067


 80%|████████  | 72/90 [15:52<04:16, 14.25s/it]

test:  tensor(78.0884) 0.8636644207514249
train:  tensor(99.5235) 0.027756949811167008


 81%|████████  | 73/90 [16:07<04:04, 14.39s/it]

test:  tensor(77.0481) 0.9248877992996802
train:  tensor(99.8700) 0.011764458430981313


 82%|████████▏ | 74/90 [16:22<03:54, 14.63s/it]

test:  tensor(79.1938) 1.0441315815998957
train:  tensor(99.9134) 0.020144205540418625


 83%|████████▎ | 75/90 [16:36<03:38, 14.55s/it]

test:  tensor(75.1625) 0.9377765002158972
train:  tensor(96.7728) 0.13187743422600465


 84%|████████▍ | 76/90 [16:50<03:19, 14.26s/it]

test:  tensor(73.3420) 1.0159879779586425
train:  tensor(98.6788) 0.06878877685380143


 86%|████████▌ | 77/90 [17:04<03:04, 14.22s/it]

test:  tensor(73.9272) 1.000814743196735
train:  tensor(98.0074) 0.09566324168967234


 87%|████████▋ | 78/90 [17:17<02:48, 14.05s/it]

test:  tensor(73.6021) 0.9799436846604714
train:  tensor(99.7401) 0.02739059619253149


 88%|████████▊ | 79/90 [17:32<02:35, 14.15s/it]

test:  tensor(74.7074) 0.9657036782457278
train:  tensor(98.7654) 0.060755516503106906


 89%|████████▉ | 80/90 [17:44<02:16, 13.69s/it]

test:  tensor(74.2523) 0.9542955747411515
train:  tensor(99.8700) 0.01664341276360525


 90%|█████████ | 81/90 [17:57<01:59, 13.33s/it]

test:  tensor(78.0884) 0.8941660890212426
train:  tensor(99.8700) 0.01377656901053883


 91%|█████████ | 82/90 [18:10<01:44, 13.11s/it]

test:  tensor(77.6333) 0.8895146938470694
train:  tensor(99.9350) 0.006999896757455694


 92%|█████████▏| 83/90 [18:22<01:30, 12.96s/it]

test:  tensor(77.8283) 0.9094048004883987
train:  tensor(99.9134) 0.010793028846131387


 93%|█████████▎| 84/90 [18:35<01:17, 12.89s/it]

test:  tensor(77.9584) 0.8759379148339996
train:  tensor(100.) 0.004883966043692183


 94%|█████████▍| 85/90 [18:47<01:03, 12.75s/it]

test:  tensor(78.9987) 0.8307587178603101
train:  tensor(100.) 0.005467076855082367


 96%|█████████▌| 86/90 [19:00<00:50, 12.67s/it]

test:  tensor(78.8687) 0.9067270893317002
train:  tensor(99.5885) 0.017006750558377116


 97%|█████████▋| 87/90 [19:12<00:37, 12.60s/it]

test:  tensor(77.0481) 0.9300932997766023
train:  tensor(100.) 0.0050463350741444405


 98%|█████████▊| 88/90 [19:25<00:25, 12.57s/it]

test:  tensor(78.2185) 0.8915679454803467
train:  tensor(99.9783) 0.003845238557865692


 99%|█████████▉| 89/90 [19:37<00:12, 12.56s/it]

test:  tensor(78.6086) 1.0168933547460115
train:  tensor(99.9567) 0.00510461103938822


100%|██████████| 90/90 [19:50<00:00, 13.23s/it]

test:  tensor(77.2432) 0.8935288511789762





<Figure size 1400x600 with 0 Axes>

##### check the weight of model before training

In [144]:
my_model.conv2_1.weight

Parameter containing:
tensor([[[[-4.1152e-02,  2.7552e-02,  2.3232e-02],
          [-5.0019e-02, -3.7431e-02,  2.7063e-02],
          [ 4.6284e-03, -4.1788e-02,  3.4469e-02]],

         [[ 1.0118e-03, -3.1182e-02, -1.6825e-02],
          [-3.5135e-02, -4.1477e-02,  4.8241e-03],
          [ 1.6821e-02, -2.0892e-02, -2.0368e-02]],

         [[ 2.6120e-02, -2.7894e-02, -1.2092e-02],
          [ 1.3028e-02, -3.5441e-02, -1.2177e-02],
          [-3.2591e-02, -2.6870e-02,  3.9030e-02]],

         ...,

         [[ 2.6396e-02,  3.2726e-02,  7.4274e-03],
          [-8.5307e-03, -2.7306e-02, -3.1041e-02],
          [-1.1025e-02, -1.9495e-02,  2.3737e-02]],

         [[-2.6593e-02,  3.9382e-02, -8.3937e-03],
          [-3.1704e-02, -6.0318e-03,  5.3831e-03],
          [-4.7803e-02, -4.3978e-02,  5.3610e-06]],

         [[-3.1691e-02,  1.3603e-03, -2.7139e-03],
          [ 1.4907e-02, -1.9542e-03,  8.6123e-05],
          [-7.5084e-03, -3.1214e-03,  1.3781e-02]]],


        [[[-1.7927e-02,  1.9769

In [145]:
my_model.fc1.weight

Parameter containing:
tensor([[-0.0148, -0.0265, -0.0253,  ...,  0.0032,  0.0118, -0.0070],
        [ 0.0125,  0.0121,  0.0150,  ..., -0.0289,  0.0135,  0.0117],
        [-0.0210, -0.0359, -0.0215,  ...,  0.0106,  0.0034,  0.0037],
        ...,
        [ 0.0109,  0.0057, -0.0067,  ...,  0.0134, -0.0179,  0.0016],
        [ 0.0017, -0.0018, -0.0309,  ..., -0.0008, -0.0080, -0.0014],
        [-0.0366, -0.0081,  0.0072,  ..., -0.0128, -0.0209,  0.0118]],
       device='cuda:0', requires_grad=True)

# Phase 2

## First Way

In [146]:
def transfer_learning(model, freeze=False):
    """
    Perform transfer learning by modifying the given model.

    Args:
    - model: Pretrained neural network model
    - freeze (bool): Flag to freeze the pretrained weights (default is False)

    Returns:
    - new_model: Modified neural network model for transfer learning
    """
    # Deep copy the model to avoid modifying the original
    copyModel = copy.deepcopy(model)
    # Create a new model instance
    new_model = type(copyModel)()
    
    # Iterate over child modules of the copied model
    for child_name, child_module in copyModel.named_children():
        # Exclude certain layers from transfer learning
        if child_name not in ["fc1", "output"]:
            new_model.add_module(child_name, child_module)
            
    # Freeze the pretrained weights if specified
    if freeze:
        for param in new_model.parameters():
            param.requires_grad = False
    
    # Extract dimensions of input and output layers of the original model
    old_input_classes = copyModel.fc1.in_features
    new_fc_len = copyModel.fc1.out_features + 20
    
    # Extract weights and bias from the original fully connected layer
    old_weights = copyModel.fc1.weight.data
    old_bias = copyModel.fc1.bias.data
            
    # Create a new fully connected layer with modified dimensions
    new_fc = nn.Linear(old_input_classes, new_fc_len)
    
    # Copy pretrained weights and bias to the new fully connected layer
    with torch.no_grad():
        new_fc.weight[:80, :old_input_classes].copy_(old_weights)
        new_fc.bias[:80].copy_(old_bias)
        
    # Add the new fully connected layer to the new model
    new_model.add_module("fc1", new_fc)
    
    return new_model

In [147]:
newModel1 = transfer_learning(model=my_model)

#### weights after transfer

In [149]:
newModel1.conv2_1.weight

Parameter containing:
tensor([[[[-4.1152e-02,  2.7552e-02,  2.3232e-02],
          [-5.0019e-02, -3.7431e-02,  2.7063e-02],
          [ 4.6284e-03, -4.1788e-02,  3.4469e-02]],

         [[ 1.0118e-03, -3.1182e-02, -1.6825e-02],
          [-3.5135e-02, -4.1477e-02,  4.8241e-03],
          [ 1.6821e-02, -2.0892e-02, -2.0368e-02]],

         [[ 2.6120e-02, -2.7894e-02, -1.2092e-02],
          [ 1.3028e-02, -3.5441e-02, -1.2177e-02],
          [-3.2591e-02, -2.6870e-02,  3.9030e-02]],

         ...,

         [[ 2.6396e-02,  3.2726e-02,  7.4274e-03],
          [-8.5307e-03, -2.7306e-02, -3.1041e-02],
          [-1.1025e-02, -1.9495e-02,  2.3737e-02]],

         [[-2.6593e-02,  3.9382e-02, -8.3937e-03],
          [-3.1704e-02, -6.0318e-03,  5.3831e-03],
          [-4.7803e-02, -4.3978e-02,  5.3610e-06]],

         [[-3.1691e-02,  1.3603e-03, -2.7139e-03],
          [ 1.4907e-02, -1.9542e-03,  8.6123e-05],
          [-7.5084e-03, -3.1214e-03,  1.3781e-02]]],


        [[[-1.7927e-02,  1.9769

In [59]:
newModel1.fc1.weight

Parameter containing:
tensor([[-0.0170, -0.0145, -0.0249,  ..., -0.0205, -0.0112, -0.0009],
        [-0.0257,  0.0155, -0.0224,  ...,  0.0049, -0.0210, -0.0069],
        [-0.0129,  0.0178, -0.0125,  ..., -0.0094, -0.0255, -0.0131],
        ...,
        [ 0.0121,  0.0036, -0.0042,  ...,  0.0038, -0.0079,  0.0064],
        [-0.0148,  0.0133,  0.0133,  ...,  0.0150,  0.0030, -0.0145],
        [-0.0070,  0.0070, -0.0134,  ...,  0.0072, -0.0007,  0.0016]],
       requires_grad=True)

### Train on B

In [150]:
run(train_data=B_train_dl, test_data=B_test_dl, input_model=newModel1, index=2, epochs=20, lr=0.0001, test=False)

  0%|          | 0/20 [00:00<?, ?it/s]

train:  tensor(3.) 3.7360117435455322


  5%|▌         | 1/20 [00:01<00:37,  1.95s/it]

test:  tensor(3.8610) 3.3133307456970216
train:  tensor(17.) 3.0067801475524902


 10%|█         | 2/20 [00:03<00:32,  1.81s/it]

test:  tensor(6.1776) 3.1750799655914306
train:  tensor(33.) 2.588594913482666


 15%|█▌        | 3/20 [00:05<00:29,  1.74s/it]

test:  tensor(10.8108) 3.006966733932495
train:  tensor(69.) 2.264235019683838


 20%|██        | 4/20 [00:07<00:27,  1.74s/it]

test:  tensor(17.9537) 2.811112403869629
train:  tensor(81.) 1.9681631326675415


 25%|██▌       | 5/20 [00:08<00:25,  1.73s/it]

test:  tensor(24.3243) 2.6742984771728517
train:  tensor(90.) 1.7198048830032349


 30%|███       | 6/20 [00:10<00:24,  1.72s/it]

test:  tensor(28.7645) 2.5740155220031737
train:  tensor(92.0000) 1.4736251831054688


 35%|███▌      | 7/20 [00:12<00:22,  1.70s/it]

test:  tensor(34.5560) 2.4303266525268556
train:  tensor(93.) 1.3176366090774536


 40%|████      | 8/20 [00:13<00:20,  1.70s/it]

test:  tensor(36.8726) 2.3144909381866454
train:  tensor(96.) 1.1138050556182861


 45%|████▌     | 9/20 [00:15<00:18,  1.69s/it]

test:  tensor(38.9961) 2.2800100326538084
train:  tensor(95.) 0.9692810773849487


 50%|█████     | 10/20 [00:17<00:16,  1.69s/it]

test:  tensor(39.7683) 2.2536521434783934
train:  tensor(98.0000) 0.79969322681427


 55%|█████▌    | 11/20 [00:18<00:15,  1.69s/it]

test:  tensor(40.3475) 2.1762888431549072
train:  tensor(99.0000) 0.6416202783584595


 60%|██████    | 12/20 [00:20<00:13,  1.69s/it]

test:  tensor(40.9266) 2.0172800064086913
train:  tensor(99.0000) 0.5449172854423523


 65%|██████▌   | 13/20 [00:22<00:11,  1.69s/it]

test:  tensor(42.2780) 1.9249638080596925
train:  tensor(99.0000) 0.49250632524490356


 70%|███████   | 14/20 [00:23<00:10,  1.69s/it]

test:  tensor(43.2432) 2.0330743312835695
train:  tensor(100.) 0.4017335772514343


 75%|███████▌  | 15/20 [00:25<00:08,  1.69s/it]

test:  tensor(43.0502) 1.9388537883758545
train:  tensor(100.) 0.3293105363845825


 80%|████████  | 16/20 [00:27<00:06,  1.69s/it]

test:  tensor(43.2432) 1.9018833875656127
train:  tensor(100.) 0.27055078744888306


 85%|████████▌ | 17/20 [00:28<00:05,  1.68s/it]

test:  tensor(44.0154) 1.8495002031326293
train:  tensor(100.) 0.2395671010017395


 90%|█████████ | 18/20 [00:30<00:03,  1.68s/it]

test:  tensor(44.0154) 1.9650354862213135
train:  tensor(100.) 0.19845321774482727


 95%|█████████▌| 19/20 [00:32<00:01,  1.69s/it]

test:  tensor(43.4363) 1.9730603456497193
train:  tensor(100.) 0.16455689072608948


100%|██████████| 20/20 [00:34<00:00,  1.70s/it]

test:  tensor(43.2432) 1.9940651416778565





<Figure size 1400x600 with 0 Axes>

#### weights after train 

In [151]:
newModel1.conv2_1.weight

Parameter containing:
tensor([[[[-0.0412,  0.0273,  0.0228],
          [-0.0505, -0.0379,  0.0266],
          [ 0.0041, -0.0421,  0.0342]],

         [[ 0.0009, -0.0311, -0.0161],
          [-0.0348, -0.0408,  0.0056],
          [ 0.0172, -0.0198, -0.0195]],

         [[ 0.0264, -0.0272, -0.0119],
          [ 0.0135, -0.0348, -0.0117],
          [-0.0320, -0.0265,  0.0389]],

         ...,

         [[ 0.0268,  0.0332,  0.0080],
          [-0.0082, -0.0270, -0.0306],
          [-0.0108, -0.0192,  0.0241]],

         [[-0.0262,  0.0391, -0.0088],
          [-0.0318, -0.0060,  0.0052],
          [-0.0477, -0.0442, -0.0002]],

         [[-0.0310,  0.0020, -0.0021],
          [ 0.0155, -0.0014,  0.0006],
          [-0.0070, -0.0027,  0.0142]]],


        [[[-0.0170,  0.0210,  0.0433],
          [-0.0281,  0.0368,  0.0209],
          [ 0.0300,  0.0135,  0.0324]],

         [[-0.0160,  0.0275, -0.0373],
          [-0.0280, -0.0265, -0.0431],
          [ 0.0223, -0.0215, -0.0303]],

         

###

In [152]:
newModel1.fc1.weight

Parameter containing:
tensor([[-0.0158, -0.0277, -0.0263,  ...,  0.0020,  0.0107, -0.0081],
        [ 0.0110,  0.0106,  0.0136,  ..., -0.0302,  0.0122,  0.0104],
        [-0.0224, -0.0372, -0.0229,  ...,  0.0094,  0.0021,  0.0028],
        ...,
        [-0.0052, -0.0069,  0.0171,  ...,  0.0010,  0.0077,  0.0102],
        [ 0.0084, -0.0062, -0.0088,  ..., -0.0137, -0.0104,  0.0009],
        [ 0.0014, -0.0123,  0.0134,  ..., -0.0027, -0.0044,  0.0150]],
       device='cuda:0', requires_grad=True)

#### test on A test after train on B

In [75]:
nm = newModel1.to("cpu")
acc, _ = test_model(nm, A_test_dl, nn.CrossEntropyLoss(), dir="Conf_B1_after_train_A-test", cpu=True)
print(acc)

tensor(9.0377)


#### for test all and plot test curve

In [None]:
run(B_train_dl , test_all, newModel1, index=3, epochs=20, lr=0.0001, cm=True, dir="CM1")

## Second Way

In [153]:
new_model2 = transfer_learning(my_model, freeze=True)

#### weight after transfer

In [154]:
new_model2.conv2_1.weight

Parameter containing:
tensor([[[[-4.1152e-02,  2.7552e-02,  2.3232e-02],
          [-5.0019e-02, -3.7431e-02,  2.7063e-02],
          [ 4.6284e-03, -4.1788e-02,  3.4469e-02]],

         [[ 1.0118e-03, -3.1182e-02, -1.6825e-02],
          [-3.5135e-02, -4.1477e-02,  4.8241e-03],
          [ 1.6821e-02, -2.0892e-02, -2.0368e-02]],

         [[ 2.6120e-02, -2.7894e-02, -1.2092e-02],
          [ 1.3028e-02, -3.5441e-02, -1.2177e-02],
          [-3.2591e-02, -2.6870e-02,  3.9030e-02]],

         ...,

         [[ 2.6396e-02,  3.2726e-02,  7.4274e-03],
          [-8.5307e-03, -2.7306e-02, -3.1041e-02],
          [-1.1025e-02, -1.9495e-02,  2.3737e-02]],

         [[-2.6593e-02,  3.9382e-02, -8.3937e-03],
          [-3.1704e-02, -6.0318e-03,  5.3831e-03],
          [-4.7803e-02, -4.3978e-02,  5.3610e-06]],

         [[-3.1691e-02,  1.3603e-03, -2.7139e-03],
          [ 1.4907e-02, -1.9542e-03,  8.6123e-05],
          [-7.5084e-03, -3.1214e-03,  1.3781e-02]]],


        [[[-1.7927e-02,  1.9769

In [155]:
new_model2.fc1.weight

Parameter containing:
tensor([[-0.0148, -0.0265, -0.0253,  ...,  0.0032,  0.0118, -0.0070],
        [ 0.0125,  0.0121,  0.0150,  ..., -0.0289,  0.0135,  0.0117],
        [-0.0210, -0.0359, -0.0215,  ...,  0.0106,  0.0034,  0.0037],
        ...,
        [-0.0057,  0.0136,  0.0093,  ...,  0.0092,  0.0100, -0.0023],
        [ 0.0122,  0.0152,  0.0062,  ...,  0.0132,  0.0029,  0.0114],
        [-0.0082,  0.0127,  0.0095,  ..., -0.0137,  0.0123, -0.0109]],
       requires_grad=True)

### Train on B

In [156]:
run(B_train_dl, B_test_dl, new_model2, index=4, epochs=40, lr=0.0001)

  0%|          | 0/40 [00:00<?, ?it/s]

train:  tensor(2.) 3.7205376625061035


  2%|▎         | 1/40 [00:01<01:05,  1.68s/it]

test:  tensor(2.8958) 3.452285575866699
train:  tensor(6.) 3.5804147720336914


  5%|▌         | 2/40 [00:03<01:00,  1.60s/it]

test:  tensor(3.8610) 3.494399166107178
train:  tensor(9.) 3.398089647293091


  8%|▊         | 3/40 [00:04<00:58,  1.58s/it]

test:  tensor(5.0193) 3.332865333557129
train:  tensor(10.0000) 3.326887845993042


 10%|█         | 4/40 [00:06<00:56,  1.57s/it]

test:  tensor(5.5985) 3.223168420791626
train:  tensor(10.0000) 3.206458806991577


 12%|█▎        | 5/40 [00:07<00:54,  1.56s/it]

test:  tensor(6.3707) 3.17399582862854
train:  tensor(12.) 3.0391805171966553


 15%|█▌        | 6/40 [00:09<00:52,  1.55s/it]

test:  tensor(7.1429) 3.087889385223389
train:  tensor(12.) 2.9905405044555664


 18%|█▊        | 7/40 [00:10<00:51,  1.56s/it]

test:  tensor(7.9151) 3.1470765113830566
train:  tensor(19.) 2.8615405559539795


 20%|██        | 8/40 [00:12<00:49,  1.56s/it]

test:  tensor(9.6525) 3.0680490493774415
train:  tensor(25.) 2.76902437210083


 22%|██▎       | 9/40 [00:14<00:48,  1.56s/it]

test:  tensor(11.3900) 2.936030626296997
train:  tensor(27.0000) 2.6793978214263916


 25%|██▌       | 10/40 [00:15<00:46,  1.56s/it]

test:  tensor(13.7066) 2.8933103561401365
train:  tensor(34.) 2.570929765701294


 28%|██▊       | 11/40 [00:17<00:44,  1.55s/it]

test:  tensor(15.6371) 2.8251014709472657
train:  tensor(37.) 2.5132410526275635


 30%|███       | 12/40 [00:18<00:43,  1.55s/it]

test:  tensor(17.7606) 2.7907477378845216
train:  tensor(43.0000) 2.416928768157959


 32%|███▎      | 13/40 [00:20<00:41,  1.55s/it]

test:  tensor(20.8494) 2.795509147644043
train:  tensor(53.0000) 2.3729865550994873


 35%|███▌      | 14/40 [00:21<00:40,  1.56s/it]

test:  tensor(22.5869) 2.705788516998291
train:  tensor(57.) 2.2269628047943115


 38%|███▊      | 15/40 [00:23<00:38,  1.55s/it]

test:  tensor(24.3243) 2.711904001235962
train:  tensor(58.) 2.177650213241577


 40%|████      | 16/40 [00:24<00:37,  1.56s/it]

test:  tensor(26.0618) 2.66068754196167
train:  tensor(65.) 2.128427505493164


 42%|████▎     | 17/40 [00:26<00:35,  1.56s/it]

test:  tensor(26.8340) 2.612897348403931
train:  tensor(69.) 2.038872003555298


 45%|████▌     | 18/40 [00:28<00:34,  1.55s/it]

test:  tensor(28.7645) 2.612540102005005
train:  tensor(70.) 1.9858895540237427


 48%|████▊     | 19/40 [00:29<00:32,  1.55s/it]

test:  tensor(30.3089) 2.483223628997803
train:  tensor(72.) 1.9363417625427246


 50%|█████     | 20/40 [00:31<00:30,  1.55s/it]

test:  tensor(31.2741) 2.5079375743865966
train:  tensor(75.) 1.8485203981399536


 52%|█████▎    | 21/40 [00:32<00:29,  1.54s/it]

test:  tensor(31.0811) 2.487429714202881
train:  tensor(74.) 1.8320926427841187


 55%|█████▌    | 22/40 [00:34<00:27,  1.55s/it]

test:  tensor(32.6255) 2.4301040172576904
train:  tensor(83.) 1.7352464199066162


 57%|█████▊    | 23/40 [00:35<00:26,  1.55s/it]

test:  tensor(32.8185) 2.4164044857025146
train:  tensor(87.) 1.6739602088928223


 60%|██████    | 24/40 [00:37<00:24,  1.55s/it]

test:  tensor(33.3977) 2.4468355655670164
train:  tensor(86.0000) 1.5998294353485107


 62%|██████▎   | 25/40 [00:38<00:23,  1.54s/it]

test:  tensor(33.5907) 2.4870423316955566
train:  tensor(87.) 1.5730395317077637


 65%|██████▌   | 26/40 [00:40<00:21,  1.55s/it]

test:  tensor(35.1351) 2.413618040084839
train:  tensor(86.0000) 1.5237082242965698


 68%|██████▊   | 27/40 [00:42<00:20,  1.56s/it]

test:  tensor(35.1351) 2.2807663679122925
train:  tensor(91.) 1.4669337272644043


 70%|███████   | 28/40 [00:43<00:18,  1.56s/it]

test:  tensor(35.3282) 2.31563663482666
train:  tensor(94.) 1.394215703010559


 72%|███████▎  | 29/40 [00:45<00:17,  1.56s/it]

test:  tensor(35.7143) 2.3173510074615478
train:  tensor(89.) 1.3701772689819336


 75%|███████▌  | 30/40 [00:46<00:15,  1.57s/it]

test:  tensor(36.1004) 2.3146987438201903
train:  tensor(95.) 1.3139219284057617


 78%|███████▊  | 31/40 [00:48<00:14,  1.56s/it]

test:  tensor(36.1004) 2.284729814529419
train:  tensor(92.0000) 1.2769190073013306


 80%|████████  | 32/40 [00:49<00:12,  1.56s/it]

test:  tensor(36.6795) 2.3058104515075684
train:  tensor(92.0000) 1.2597386837005615


 82%|████████▎ | 33/40 [00:51<00:10,  1.55s/it]

test:  tensor(37.6448) 2.3090025424957275
train:  tensor(96.) 1.1979601383209229


 85%|████████▌ | 34/40 [00:52<00:09,  1.56s/it]

test:  tensor(38.2239) 2.18477201461792
train:  tensor(96.) 1.1361417770385742


 88%|████████▊ | 35/40 [00:54<00:07,  1.56s/it]

test:  tensor(38.6100) 2.2373184680938722
train:  tensor(95.) 1.1066319942474365


 90%|█████████ | 36/40 [00:56<00:06,  1.56s/it]

test:  tensor(38.4170) 2.278513526916504
train:  tensor(93.) 1.0943814516067505


 92%|█████████▎| 37/40 [00:57<00:04,  1.56s/it]

test:  tensor(38.9961) 2.2265177726745606
train:  tensor(92.0000) 1.0491719245910645


 95%|█████████▌| 38/40 [00:59<00:03,  1.56s/it]

test:  tensor(39.1892) 2.106191897392273
train:  tensor(94.) 1.0388538837432861


 98%|█████████▊| 39/40 [01:00<00:01,  1.56s/it]

test:  tensor(39.1892) 2.230818843841553
train:  tensor(97.) 0.9560943841934204


100%|██████████| 40/40 [01:02<00:00,  1.56s/it]

test:  tensor(40.3475) 2.117531442642212





<Figure size 1400x600 with 0 Axes>

#### weights after train 

In [157]:
new_model2.conv2_1.weight

Parameter containing:
tensor([[[[-4.1152e-02,  2.7552e-02,  2.3232e-02],
          [-5.0019e-02, -3.7431e-02,  2.7063e-02],
          [ 4.6284e-03, -4.1788e-02,  3.4469e-02]],

         [[ 1.0118e-03, -3.1182e-02, -1.6825e-02],
          [-3.5135e-02, -4.1477e-02,  4.8241e-03],
          [ 1.6821e-02, -2.0892e-02, -2.0368e-02]],

         [[ 2.6120e-02, -2.7894e-02, -1.2092e-02],
          [ 1.3028e-02, -3.5441e-02, -1.2177e-02],
          [-3.2591e-02, -2.6870e-02,  3.9030e-02]],

         ...,

         [[ 2.6396e-02,  3.2726e-02,  7.4274e-03],
          [-8.5307e-03, -2.7306e-02, -3.1041e-02],
          [-1.1025e-02, -1.9495e-02,  2.3737e-02]],

         [[-2.6593e-02,  3.9382e-02, -8.3937e-03],
          [-3.1704e-02, -6.0318e-03,  5.3831e-03],
          [-4.7803e-02, -4.3978e-02,  5.3610e-06]],

         [[-3.1691e-02,  1.3603e-03, -2.7139e-03],
          [ 1.4907e-02, -1.9542e-03,  8.6123e-05],
          [-7.5084e-03, -3.1214e-03,  1.3781e-02]]],


        [[[-1.7927e-02,  1.9769

In [158]:
new_model2.fc1.weight

Parameter containing:
tensor([[-0.0170, -0.0284, -0.0274,  ...,  0.0014,  0.0098, -0.0091],
        [ 0.0108,  0.0102,  0.0129,  ..., -0.0309,  0.0114,  0.0098],
        [-0.0230, -0.0378, -0.0232,  ...,  0.0085,  0.0016,  0.0020],
        ...,
        [-0.0024,  0.0171,  0.0120,  ...,  0.0056,  0.0121, -0.0062],
        [ 0.0083,  0.0113,  0.0027,  ...,  0.0114,  0.0043,  0.0089],
        [-0.0056,  0.0159,  0.0126,  ..., -0.0106,  0.0153, -0.0079]],
       device='cuda:0', requires_grad=True)

#### test on A after train on B with freezing all layers except fc

In [159]:
nm1 = new_model2.to("cpu")
epoch_acc, epoch_loss = test_model(nm1, A_test_dl, loss_fn=nn.CrossEntropyLoss(), dir="Conf_B2_after_train_A-test", cpu=True)
print(epoch_acc)

tensor(47.3992)


#### for test all and plot test curve

In [None]:
run(B_train_dl , test_all, new_model2, index=5, epochs=20, lr=0.0001, cm=True, dir="CM2")

## Third Way

In [160]:
new_model3 = transfer_learning(my_model, freeze=True)

##### weights after transfer

In [161]:
new_model3.conv2_1.weight

Parameter containing:
tensor([[[[-4.1152e-02,  2.7552e-02,  2.3232e-02],
          [-5.0019e-02, -3.7431e-02,  2.7063e-02],
          [ 4.6284e-03, -4.1788e-02,  3.4469e-02]],

         [[ 1.0118e-03, -3.1182e-02, -1.6825e-02],
          [-3.5135e-02, -4.1477e-02,  4.8241e-03],
          [ 1.6821e-02, -2.0892e-02, -2.0368e-02]],

         [[ 2.6120e-02, -2.7894e-02, -1.2092e-02],
          [ 1.3028e-02, -3.5441e-02, -1.2177e-02],
          [-3.2591e-02, -2.6870e-02,  3.9030e-02]],

         ...,

         [[ 2.6396e-02,  3.2726e-02,  7.4274e-03],
          [-8.5307e-03, -2.7306e-02, -3.1041e-02],
          [-1.1025e-02, -1.9495e-02,  2.3737e-02]],

         [[-2.6593e-02,  3.9382e-02, -8.3937e-03],
          [-3.1704e-02, -6.0318e-03,  5.3831e-03],
          [-4.7803e-02, -4.3978e-02,  5.3610e-06]],

         [[-3.1691e-02,  1.3603e-03, -2.7139e-03],
          [ 1.4907e-02, -1.9542e-03,  8.6123e-05],
          [-7.5084e-03, -3.1214e-03,  1.3781e-02]]],


        [[[-1.7927e-02,  1.9769

In [162]:
new_model3.fc1.weight

Parameter containing:
tensor([[-0.0148, -0.0265, -0.0253,  ...,  0.0032,  0.0118, -0.0070],
        [ 0.0125,  0.0121,  0.0150,  ..., -0.0289,  0.0135,  0.0117],
        [-0.0210, -0.0359, -0.0215,  ...,  0.0106,  0.0034,  0.0037],
        ...,
        [-0.0082,  0.0088, -0.0012,  ...,  0.0147, -0.0134,  0.0134],
        [ 0.0123,  0.0021,  0.0028,  ..., -0.0142,  0.0143, -0.0069],
        [ 0.0066,  0.0025, -0.0051,  ..., -0.0004,  0.0048,  0.0136]],
       requires_grad=True)

### train third model 

In [163]:
run(B_train_dl, B_test_dl, new_model3, index=21, epochs=50, lr=0.0001, freeze_fc=True)

  0%|          | 0/50 [00:00<?, ?it/s]

train:  tensor(4.) 3.7633426189422607


  2%|▏         | 1/50 [00:01<01:27,  1.78s/it]

test:  tensor(2.7027) 3.6670021533966066
train:  tensor(5.0000) 3.632528781890869


  4%|▍         | 2/50 [00:03<01:18,  1.65s/it]

test:  tensor(2.8958) 3.5799498558044434
train:  tensor(6.) 3.5751843452453613


  6%|▌         | 3/50 [00:04<01:15,  1.62s/it]

test:  tensor(3.4749) 3.504531145095825
train:  tensor(4.) 3.494377851486206


  8%|▊         | 4/50 [00:06<01:13,  1.60s/it]

test:  tensor(4.2471) 3.3851584434509276
train:  tensor(6.) 3.415786027908325


 10%|█         | 5/50 [00:08<01:11,  1.59s/it]

test:  tensor(4.8263) 3.430712270736694
train:  tensor(10.0000) 3.3008315563201904


 12%|█▏        | 6/50 [00:09<01:10,  1.59s/it]

test:  tensor(5.0193) 3.3640034198760986
train:  tensor(8.) 3.2823526859283447


 14%|█▍        | 7/50 [00:11<01:08,  1.58s/it]

test:  tensor(5.5985) 3.2963839530944825
train:  tensor(13.) 3.0998106002807617


 16%|█▌        | 8/50 [00:12<01:06,  1.59s/it]

test:  tensor(6.5637) 3.3062639236450195
train:  tensor(17.) 3.0686142444610596


 18%|█▊        | 9/50 [00:14<01:04,  1.58s/it]

test:  tensor(7.7220) 3.1331418991088866
train:  tensor(16.) 2.9680557250976562


 20%|██        | 10/50 [00:15<01:03,  1.58s/it]

test:  tensor(8.6873) 3.165080451965332
train:  tensor(19.) 2.931849479675293


 22%|██▏       | 11/50 [00:17<01:01,  1.57s/it]

test:  tensor(9.2664) 3.2517873287200927
train:  tensor(26.) 2.844069719314575


 24%|██▍       | 12/50 [00:19<00:59,  1.57s/it]

test:  tensor(10.0386) 3.1183118343353273
train:  tensor(24.) 2.7437124252319336


 26%|██▌       | 13/50 [00:20<00:57,  1.56s/it]

test:  tensor(10.4247) 3.195771265029907
train:  tensor(28.) 2.690646171569824


 28%|██▊       | 14/50 [00:22<00:56,  1.56s/it]

test:  tensor(11.1969) 3.1293596744537355
train:  tensor(28.) 2.6335134506225586


 30%|███       | 15/50 [00:23<00:54,  1.57s/it]

test:  tensor(11.3900) 3.0079551219940184
train:  tensor(29.) 2.586172103881836


 32%|███▏      | 16/50 [00:25<00:53,  1.56s/it]

test:  tensor(12.5483) 2.963643503189087
train:  tensor(34.) 2.472047805786133


 34%|███▍      | 17/50 [00:26<00:51,  1.56s/it]

test:  tensor(13.7066) 3.052238178253174
train:  tensor(40.0000) 2.4068422317504883


 36%|███▌      | 18/50 [00:28<00:50,  1.57s/it]

test:  tensor(14.0927) 2.988345670700073
train:  tensor(48.) 2.306119680404663


 38%|███▊      | 19/50 [00:30<00:48,  1.57s/it]

test:  tensor(15.8301) 2.9665701389312744
train:  tensor(52.) 2.273719549179077


 40%|████      | 20/50 [00:31<00:46,  1.56s/it]

test:  tensor(17.7606) 2.781074333190918
train:  tensor(48.) 2.186476469039917


 42%|████▏     | 21/50 [00:33<00:45,  1.57s/it]

test:  tensor(18.5328) 2.7545422077178956
train:  tensor(52.) 2.149423599243164


 44%|████▍     | 22/50 [00:34<00:43,  1.57s/it]

test:  tensor(19.6911) 2.7425349712371827
train:  tensor(53.0000) 2.132357358932495


 46%|████▌     | 23/50 [00:36<00:42,  1.57s/it]

test:  tensor(20.0772) 2.718258810043335
train:  tensor(67.0000) 1.9725656509399414


 48%|████▊     | 24/50 [00:37<00:40,  1.57s/it]

test:  tensor(21.0425) 2.7590802669525147
train:  tensor(67.0000) 1.9564793109893799


 50%|█████     | 25/50 [00:39<00:39,  1.57s/it]

test:  tensor(22.0077) 2.6372655868530273
train:  tensor(63.) 1.9150112867355347


 52%|█████▏    | 26/50 [00:41<00:37,  1.57s/it]

test:  tensor(22.9730) 2.643167781829834
train:  tensor(64.) 1.8913525342941284


 54%|█████▍    | 27/50 [00:42<00:36,  1.57s/it]

test:  tensor(23.3591) 2.6880624771118162
train:  tensor(73.0000) 1.7579536437988281


 56%|█████▌    | 28/50 [00:44<00:34,  1.57s/it]

test:  tensor(24.1313) 2.615520191192627
train:  tensor(69.) 1.7693709135055542


 58%|█████▊    | 29/50 [00:45<00:32,  1.57s/it]

test:  tensor(24.7104) 2.625764989852905
train:  tensor(72.) 1.6677772998809814


 60%|██████    | 30/50 [00:47<00:31,  1.58s/it]

test:  tensor(26.4479) 2.632811164855957
train:  tensor(74.) 1.6390678882598877


 62%|██████▏   | 31/50 [00:48<00:30,  1.58s/it]

test:  tensor(27.0270) 2.5469154834747316
train:  tensor(77.) 1.5809913873672485


 64%|██████▍   | 32/50 [00:50<00:28,  1.59s/it]

test:  tensor(27.7992) 2.5070052623748778
train:  tensor(75.) 1.5374822616577148


 66%|██████▌   | 33/50 [00:52<00:27,  1.60s/it]

test:  tensor(28.3784) 2.53343186378479
train:  tensor(78.) 1.5074254274368286


 68%|██████▊   | 34/50 [00:53<00:25,  1.59s/it]

test:  tensor(29.5367) 2.4397231101989747
train:  tensor(83.) 1.430891513824463


 70%|███████   | 35/50 [00:55<00:23,  1.59s/it]

test:  tensor(30.3089) 2.4404768466949465
train:  tensor(80.0000) 1.4062846899032593


 72%|███████▏  | 36/50 [00:56<00:22,  1.59s/it]

test:  tensor(30.6950) 2.4890787601470947
train:  tensor(77.) 1.3621617555618286


 74%|███████▍  | 37/50 [00:58<00:20,  1.57s/it]

test:  tensor(31.0811) 2.4156789779663086
train:  tensor(81.) 1.31526517868042


 76%|███████▌  | 38/50 [00:59<00:18,  1.57s/it]

test:  tensor(31.8533) 2.435846948623657
train:  tensor(85.) 1.2504937648773193


 78%|███████▊  | 39/50 [01:01<00:17,  1.58s/it]

test:  tensor(32.2394) 2.37402548789978
train:  tensor(82.) 1.240095615386963


 80%|████████  | 40/50 [01:03<00:15,  1.58s/it]

test:  tensor(33.2046) 2.3569030284881594
train:  tensor(81.) 1.194214105606079


 82%|████████▏ | 41/50 [01:04<00:14,  1.57s/it]

test:  tensor(33.2046) 2.3724224090576174
train:  tensor(82.) 1.1641690731048584


 84%|████████▍ | 42/50 [01:06<00:12,  1.57s/it]

test:  tensor(33.2046) 2.2094027757644654
train:  tensor(82.) 1.1094721555709839


 86%|████████▌ | 43/50 [01:07<00:10,  1.56s/it]

test:  tensor(33.3977) 2.379234027862549
train:  tensor(89.) 1.072819709777832


 88%|████████▊ | 44/50 [01:09<00:09,  1.56s/it]

test:  tensor(33.3977) 2.4557605743408204
train:  tensor(87.) 1.0788905620574951


 90%|█████████ | 45/50 [01:10<00:07,  1.57s/it]

test:  tensor(33.3977) 2.339921474456787
train:  tensor(88.) 0.9993700981140137


 92%|█████████▏| 46/50 [01:12<00:06,  1.56s/it]

test:  tensor(33.7838) 2.390343952178955
train:  tensor(91.) 0.9773685336112976


 94%|█████████▍| 47/50 [01:14<00:04,  1.56s/it]

test:  tensor(34.1699) 2.327823352813721
train:  tensor(91.) 0.9353179335594177


 96%|█████████▌| 48/50 [01:15<00:03,  1.57s/it]

test:  tensor(35.1351) 2.1798791170120237
train:  tensor(88.) 0.935096025466919


 98%|█████████▊| 49/50 [01:17<00:01,  1.56s/it]

test:  tensor(35.5212) 2.228701448440552
train:  tensor(88.) 0.9179799556732178


100%|██████████| 50/50 [01:18<00:00,  1.57s/it]

test:  tensor(35.5212) 2.335897159576416





<Figure size 1400x600 with 0 Axes>

##### check the weights after train the model

In [164]:
new_model3.conv2_1.weight

Parameter containing:
tensor([[[[-4.1152e-02,  2.7552e-02,  2.3232e-02],
          [-5.0019e-02, -3.7431e-02,  2.7063e-02],
          [ 4.6284e-03, -4.1788e-02,  3.4469e-02]],

         [[ 1.0118e-03, -3.1182e-02, -1.6825e-02],
          [-3.5135e-02, -4.1477e-02,  4.8241e-03],
          [ 1.6821e-02, -2.0892e-02, -2.0368e-02]],

         [[ 2.6120e-02, -2.7894e-02, -1.2092e-02],
          [ 1.3028e-02, -3.5441e-02, -1.2177e-02],
          [-3.2591e-02, -2.6870e-02,  3.9030e-02]],

         ...,

         [[ 2.6396e-02,  3.2726e-02,  7.4274e-03],
          [-8.5307e-03, -2.7306e-02, -3.1041e-02],
          [-1.1025e-02, -1.9495e-02,  2.3737e-02]],

         [[-2.6593e-02,  3.9382e-02, -8.3937e-03],
          [-3.1704e-02, -6.0318e-03,  5.3831e-03],
          [-4.7803e-02, -4.3978e-02,  5.3610e-06]],

         [[-3.1691e-02,  1.3603e-03, -2.7139e-03],
          [ 1.4907e-02, -1.9542e-03,  8.6123e-05],
          [-7.5084e-03, -3.1214e-03,  1.3781e-02]]],


        [[[-1.7927e-02,  1.9769

In [165]:
new_model3.fc1.weight

Parameter containing:
tensor([[-1.4840e-02, -2.6468e-02, -2.5326e-02,  ...,  3.2259e-03,
          1.1839e-02, -7.0226e-03],
        [ 1.2509e-02,  1.2070e-02,  1.4986e-02,  ..., -2.8877e-02,
          1.3546e-02,  1.1703e-02],
        [-2.1009e-02, -3.5923e-02, -2.1500e-02,  ...,  1.0603e-02,
          3.4309e-03,  3.6793e-03],
        ...,
        [-3.8068e-03,  1.3182e-02,  2.8409e-03,  ...,  1.0496e-02,
         -9.5307e-03,  8.7829e-03],
        [ 7.4267e-03, -2.7958e-03,  6.6207e-05,  ..., -1.3116e-02,
          1.8094e-02, -9.1179e-03],
        [ 1.0495e-02,  6.7059e-03, -1.1667e-03,  ...,  3.6929e-03,
          8.8386e-03,  1.8029e-02]], device='cuda:0', requires_grad=True)

#### for test all and plot test curve

In [None]:
run(B_train_dl , test_all, new_model3, index=155, epochs=20, lr=0.0001, cm=True, freeze_fc=True, dir="CM3")