# 

In [1]:
import torch
from torchvision import transforms
# Continue with regular imports
import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn
from torchvision import transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
BATCH_SIZE = 64
IMG_SIZE = (128,128)
PATCH_SIZE = (8,8)
# Number of patch in one sequence
NUM_PATCH = int((IMG_SIZE[0] / PATCH_SIZE[0])**2)
# Number of patch out after feature extracting
T = int(NUM_PATCH / 4)
# Number of attention head
NUM_H = 4
# Number of Transformer Block
NUM_TR = 11

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [4]:
my_transform = transforms.Compose([
                                    transforms.Resize((IMG_SIZE)),
                                    transforms.RandAugment(),
                                    transforms.ToTensor()
])

In [5]:
def img_to_patch(x, patch_size, flatten_channels=True):
    """
    Inputs:
        x - torch.Tensor representing the image of shape [B, C, H, W]
        patch_size - Number of pixels per dimension of the patches (integer)
        flatten_channels - If True, the patches will be returned in a flattened format
                           as a feature vector instead of a image grid.
    """
    B, C, H, W = x.shape
    x = x.reshape(B, C, H//patch_size, patch_size, W//patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5) # [B, H', W', C, p_H, p_W]
    x = x.flatten(1,2)              # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2,4)          # [B, H'*W', C*p_H*p_W]
    return x

In [6]:
class AttentionBlock(nn.Module):

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of input and attention feature vectors
            hidden_dim - Dimensionality of hidden layer in feed-forward network
                         (usually 2-4x larger than embed_dim)
            num_heads - Number of heads to use in the Multi-Head Attention block
            dropout - Amount of dropout to apply in the feed-forward network
        """
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads,
                                          dropout=dropout)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )
        


    def forward(self, x):
        inp_x = self.layer_norm_1(x)
        x = x + self.attn(inp_x, inp_x, inp_x)[0]
        x = x + self.linear(self.layer_norm_2(x))
        return x

In [7]:
class Transpose(nn.Module):
    def __init__(self, dim1, dim2):
        super().__init__()
        self.dim1 = dim1
        self.dim2 = dim2

    def forward(self, x):
        return x.transpose(self.dim1, self.dim2)


In [8]:
class VisionTransformer(nn.Module):

    def __init__(self, embed_dim, hidden_dim,n_feature, num_channels, num_heads, num_layers, num_classes, patch_size, num_patches, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of the input feature vectors to the Transformer
            hidden_dim - Dimensionality of the hidden layer in the feed-forward networks
                         within the Transformer
            num_channels - Number of channels of the input (3 for RGB)
            num_heads - Number of heads to use in the Multi-Head Attention block
            num_layers - Number of layers to use in the Transformer
            num_classes - Number of classes to predict
            patch_size - Number of pixels that the patches have per dimension
            num_patches - Maximum number of patches an image can have
            dropout - Amount of dropout to apply in the feed-forward network and
                      on the input encoding
        """
        super().__init__()

        self.patch_size = patch_size

        # Layers/Networks
        # self.input_layer = nn.Linear(num_channels*(patch_size**2), embed_dim)
        # self.input_layer = FeatureExtractor(n_feature,embed_dim)
        self.input_layer = nn.Sequential(
            nn.Conv3d(num_channels , n_feature*3 , kernel_size = 3 , padding=1 , stride=1),
            nn.ReLU(),
            # nn.Conv3d(n_feature , n_feature*2 , kernel_size = 3 , padding=1 , stride=1),
            # nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2,1,1) , stride=(2,1,1) , padding=(0,0,0)),
            
            nn.Conv3d(n_feature*3 , n_feature*6 , 3 , padding=1 , stride=1),
            nn.ReLU(),
            # nn.Conv3d(n_feature*4 , n_feature*8 , 3 , padding=1 , stride=1),
            # nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2,1,1) , stride=(2,1,1) , padding=(0,0,0)),
            nn.Conv3d(n_feature*6 , n_feature*8 , 3 , padding=1 , stride=1),
            
            
            Transpose(2,1),
            nn.Flatten(2,-1),
            nn.Linear(n_feature*8*PATCH_SIZE[0]*PATCH_SIZE[0] ,embed_dim )
            
            
        )
        self.transformer = nn.Sequential(*[AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers)])
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
            # torch.nn.Softmax(dim=1)
        )
        self.dropout = nn.Dropout(dropout)

        # Parameters/Embeddings
        self.cls_token = nn.Parameter(torch.randn(1,1,embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1,1+int(NUM_PATCH / 4),embed_dim))


    def forward(self, x):
        # Preprocess input
        x = img_to_patch(x, self.patch_size , flatten_channels=False).transpose(1,2)
        # print(x.shape)
        B, _, _ , _ , _= x.shape
        T = int(NUM_PATCH / 4)
        x = self.input_layer(x)
        # print(x.shape)
        # Add CLS token and positional encoding
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:,:T+1]

        # Apply Transforrmer
        x = self.dropout(x)
        x = x.transpose(0, 1)
        x = self.transformer(x)

        # Perform classification prediction
        cls = x[0]
        out = self.mlp_head(cls)
        return out

In [9]:
AmirModel = VisionTransformer(embed_dim = 256 , 
                            hidden_dim = 512,
                            n_feature=8 ,
                            num_channels = 3, 
                            num_heads = NUM_H , 
                            num_layers = NUM_TR ,
                            num_classes=4,
                            patch_size=PATCH_SIZE[0] ,
                            num_patches=NUM_PATCH , 
                            dropout=0.1 ).to(device)

**************************************************************************************

In [10]:
# Compile the model 
torch.set_float32_matmul_precision('high')
AmirModel = torch.compile(AmirModel)

In [11]:
!pwd

/app/ieee/code/Classification-Clustering/Classification


In [12]:
# AmirModel.load_state_dict(torch.load("../TrackExp/" + name + ".pth"))
AmirModel.load_state_dict(torch.load("../../../../Ablation_Study/TrackExp/BestWeight_Part_1_Epoch_3_TrainLoss_0.044724611829818246_TestLoss_0.016080412089830787TrainAcc_0.9831583453496756_Testacc_0.9951923076923077.pth"))

<All keys matched successfully>

In [13]:
train_dir = "../../../ICIP_denoised_data/Data_2/train/"
test_dir = "../../../ICIP_denoised_data/Data_2/test/"

In [54]:
from torch.utils.data import default_collate
from torchvision.transforms import v2

cutmix = v2.CutMix(num_classes=2,alpha=0.4)
mixup = v2.MixUp(num_classes=2,alpha=0.5)
cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])
def collate_fn(batch):
    return cutmix_or_mixup(*default_collate(batch))


In [55]:
# Setup dataloaders
from going_modular.going_modular import data_setup, engine
from helper_functions import download_data, set_seeds, plot_loss_curves

train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(train_dir=train_dir,
                                                                                test_dir=test_dir,
                                                                                transform=my_transform,
                                                                                batch_size=BATCH_SIZE,
                                                                                num_workers = 8,
                                                                                collate_fn=collate_fn) # Could increase if we had more samples, such as here: https://arxiv.org/abs/2205.01580 (there are other improvements there too...)


In [56]:
class_names

['1', '2']

In [57]:
next(iter(train_dataloader))[0].shape ,next(iter(train_dataloader))[1]

(torch.Size([64, 3, 128, 128]),
 tensor([[1.0000, 0.0000],
         [0.1121, 0.8879],
         [0.8879, 0.1121],
         [1.0000, 0.0000],
         [0.1121, 0.8879],
         [0.8879, 0.1121],
         [0.1121, 0.8879],
         [0.8879, 0.1121],
         [1.0000, 0.0000],
         [1.0000, 0.0000],
         [1.0000, 0.0000],
         [0.1121, 0.8879],
         [0.0000, 1.0000],
         [0.8879, 0.1121],
         [1.0000, 0.0000],
         [0.1121, 0.8879],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.8879, 0.1121],
         [0.1121, 0.8879],
         [0.0000, 1.0000],
         [0.8879, 0.1121],
         [1.0000, 0.0000],
         [1.0000, 0.0000],
         [1.0000, 0.0000],
         [1.0000, 0.0000],
         [1.0000, 0.0000],
         [0.1121, 0.8879],
         [0.0000, 1.0000],
         [0.8879, 0.1121],
         [0.1121, 0.8879],
         [0.8879, 0.1121],
         [0.1121, 0.8879],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.8879, 0.112

In [58]:
print(f"train_dir: {train_dir}")
print(f"train_dir: {test_dir}")
print(f"len train_dataloader: {len(train_dataloader)}")
print(f"len test_dataloader: {len(test_dataloader)}")
print(f"device: {device}")

train_dir: ../../../ICIP_denoised_data/Data_2/train/
train_dir: ../../../ICIP_denoised_data/Data_2/test/
len train_dataloader: 115
len test_dataloader: 29
device: cuda


In [59]:
# Create optimizer and loss function
optimizer = torch.optim.SGD(AmirModel.parameters() , 0.001)
loss_fn = torch.nn.CrossEntropyLoss()

In [60]:
AmirModel = nn.Sequential(AmirModel,nn.Linear(4,2)).to(device)

In [61]:
from going_modular.going_modular import engine
part = "ICIP_final_denoise"
# Train the classifier head of the pretrained ViT feature extractor model
set_seeds()
x = engine.train(model=AmirModel,
                                      train_dataloader=train_dataloader,
                                      test_dataloader=test_dataloader,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      epochs=50,
                                      Pivot=65,
                                      Part=part,
                                      save_weights=True,
                                      device=device)

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,
  2%|██▎                                                                                                                   | 1/50 [01:16<1:02:04, 76.01s/it]

Epoch: 1 | train_loss: 0.7782 | train_acc: 0.5873 | test_loss: 0.7623 | test_acc: 0.5361


  4%|████▊                                                                                                                   | 2/50 [01:40<36:25, 45.53s/it]

Epoch: 2 | train_loss: 0.6646 | train_acc: 0.6270 | test_loss: 0.7225 | test_acc: 0.5857


  6%|███████▏                                                                                                                | 3/50 [02:04<28:08, 35.92s/it]

Epoch: 3 | train_loss: 0.6580 | train_acc: 0.6263 | test_loss: 0.7220 | test_acc: 0.5652


  8%|█████████▌                                                                                                              | 4/50 [02:29<24:03, 31.38s/it]

Epoch: 4 | train_loss: 0.6430 | train_acc: 0.6418 | test_loss: 0.7077 | test_acc: 0.5889


 10%|████████████                                                                                                            | 5/50 [02:53<21:39, 28.88s/it]

Epoch: 5 | train_loss: 0.6302 | train_acc: 0.6645 | test_loss: 0.6852 | test_acc: 0.6115


 12%|██████████████▍                                                                                                         | 6/50 [03:17<20:03, 27.35s/it]

Epoch: 6 | train_loss: 0.6199 | train_acc: 0.6715 | test_loss: 0.6653 | test_acc: 0.6417


 14%|████████████████▊                                                                                                       | 7/50 [03:42<18:54, 26.38s/it]

Epoch: 7 | train_loss: 0.6216 | train_acc: 0.6777 | test_loss: 0.6861 | test_acc: 0.6137


 16%|███████████████████▏                                                                                                    | 8/50 [04:06<18:02, 25.77s/it]

Epoch: 8 | train_loss: 0.6073 | train_acc: 0.6901 | test_loss: 0.6747 | test_acc: 0.6315


 18%|█████████████████████▌                                                                                                  | 9/50 [04:31<17:18, 25.33s/it]

Epoch: 9 | train_loss: 0.5991 | train_acc: 0.6987 | test_loss: 0.6757 | test_acc: 0.6261


 20%|███████████████████████▊                                                                                               | 10/50 [04:55<16:42, 25.05s/it]

Epoch: 10 | train_loss: 0.6070 | train_acc: 0.6945 | test_loss: 0.6464 | test_acc: 0.6455


 22%|██████████████████████████▏                                                                                            | 11/50 [05:20<16:10, 24.88s/it]

Epoch: 11 | train_loss: 0.6021 | train_acc: 0.7000 | test_loss: 0.6573 | test_acc: 0.6476


 24%|████████████████████████████▌                                                                                          | 12/50 [05:44<15:40, 24.75s/it]

Epoch: 12 | train_loss: 0.5828 | train_acc: 0.7113 | test_loss: 0.6662 | test_acc: 0.6433


 26%|██████████████████████████████▉                                                                                        | 13/50 [06:08<15:12, 24.66s/it]

Epoch: 13 | train_loss: 0.5867 | train_acc: 0.7093 | test_loss: 0.6499 | test_acc: 0.6616
[INFO] Saving model to: TrackExp/BestWeight_Part_ICIP_final_denoise_Epoch_12_TrainLoss_0.5866567520991616_TestLoss_0.6498781273077274TrainAcc_0.7092950767263426_Testacc_0.6616379310344828.pth


 28%|█████████████████████████████████▎                                                                                     | 14/50 [06:33<14:44, 24.57s/it]

Epoch: 14 | train_loss: 0.5939 | train_acc: 0.7022 | test_loss: 0.6707 | test_acc: 0.6245


 30%|███████████████████████████████████▋                                                                                   | 15/50 [06:57<14:17, 24.51s/it]

Epoch: 15 | train_loss: 0.5891 | train_acc: 0.7125 | test_loss: 0.6528 | test_acc: 0.6600


 32%|██████████████████████████████████████                                                                                 | 16/50 [07:22<13:51, 24.47s/it]

Epoch: 16 | train_loss: 0.5798 | train_acc: 0.7169 | test_loss: 0.6512 | test_acc: 0.6492


 34%|████████████████████████████████████████▍                                                                              | 17/50 [07:46<13:26, 24.44s/it]

Epoch: 17 | train_loss: 0.5782 | train_acc: 0.7214 | test_loss: 0.6581 | test_acc: 0.6325


 36%|██████████████████████████████████████████▊                                                                            | 18/50 [08:10<13:01, 24.43s/it]

Epoch: 18 | train_loss: 0.5926 | train_acc: 0.7104 | test_loss: 0.6442 | test_acc: 0.6654
[INFO] Saving model to: TrackExp/BestWeight_Part_ICIP_final_denoise_Epoch_17_TrainLoss_0.5926210504511128_TestLoss_0.6442239659613577TrainAcc_0.7104219948849105_Testacc_0.6654094827586207.pth


 38%|█████████████████████████████████████████████▏                                                                         | 19/50 [08:35<12:37, 24.42s/it]

Epoch: 19 | train_loss: 0.5667 | train_acc: 0.7359 | test_loss: 0.6306 | test_acc: 0.6686
[INFO] Saving model to: TrackExp/BestWeight_Part_ICIP_final_denoise_Epoch_18_TrainLoss_0.5667477965354919_TestLoss_0.63064814333258TrainAcc_0.7358615728900256_Testacc_0.6686422413793104.pth


 40%|███████████████████████████████████████████████▌                                                                       | 20/50 [08:59<12:12, 24.42s/it]

Epoch: 20 | train_loss: 0.5882 | train_acc: 0.7202 | test_loss: 0.6592 | test_acc: 0.6406


 42%|█████████████████████████████████████████████████▉                                                                     | 21/50 [09:24<11:47, 24.41s/it]

Epoch: 21 | train_loss: 0.5536 | train_acc: 0.7517 | test_loss: 0.6551 | test_acc: 0.6466


 44%|████████████████████████████████████████████████████▎                                                                  | 22/50 [09:48<11:23, 24.39s/it]

Epoch: 22 | train_loss: 0.5700 | train_acc: 0.7292 | test_loss: 0.6598 | test_acc: 0.6449


 46%|██████████████████████████████████████████████████████▋                                                                | 23/50 [10:12<10:58, 24.40s/it]

Epoch: 23 | train_loss: 0.5728 | train_acc: 0.7385 | test_loss: 0.6469 | test_acc: 0.6374


 48%|█████████████████████████████████████████████████████████                                                              | 24/50 [10:37<10:34, 24.40s/it]

Epoch: 24 | train_loss: 0.5544 | train_acc: 0.7422 | test_loss: 0.6597 | test_acc: 0.6374


 50%|███████████████████████████████████████████████████████████▌                                                           | 25/50 [11:01<10:09, 24.40s/it]

Epoch: 25 | train_loss: 0.5643 | train_acc: 0.7395 | test_loss: 0.6613 | test_acc: 0.6509


 52%|█████████████████████████████████████████████████████████████▉                                                         | 26/50 [11:26<09:45, 24.39s/it]

Epoch: 26 | train_loss: 0.5674 | train_acc: 0.7361 | test_loss: 0.6476 | test_acc: 0.6525


 54%|████████████████████████████████████████████████████████████████▎                                                      | 27/50 [11:50<09:23, 24.49s/it]

Epoch: 27 | train_loss: 0.5585 | train_acc: 0.7405 | test_loss: 0.6528 | test_acc: 0.6449


 56%|██████████████████████████████████████████████████████████████████▋                                                    | 28/50 [12:15<09:03, 24.70s/it]

Epoch: 28 | train_loss: 0.5642 | train_acc: 0.7410 | test_loss: 0.6431 | test_acc: 0.6730
[INFO] Saving model to: TrackExp/BestWeight_Part_ICIP_final_denoise_Epoch_27_TrainLoss_0.5641963264216547_TestLoss_0.6431121343168719TrainAcc_0.7409686700767264_Testacc_0.6729525862068966.pth


 58%|█████████████████████████████████████████████████████████████████████                                                  | 29/50 [12:40<08:39, 24.72s/it]

Epoch: 29 | train_loss: 0.5560 | train_acc: 0.7526 | test_loss: 0.6629 | test_acc: 0.6482


 60%|███████████████████████████████████████████████████████████████████████▍                                               | 30/50 [13:05<08:12, 24.62s/it]

Epoch: 30 | train_loss: 0.5403 | train_acc: 0.7607 | test_loss: 0.6908 | test_acc: 0.6412


 62%|█████████████████████████████████████████████████████████████████████████▊                                             | 31/50 [13:29<07:46, 24.56s/it]

Epoch: 31 | train_loss: 0.5483 | train_acc: 0.7482 | test_loss: 0.6874 | test_acc: 0.6336


 64%|████████████████████████████████████████████████████████████████████████████▏                                          | 32/50 [13:53<07:21, 24.51s/it]

Epoch: 32 | train_loss: 0.5529 | train_acc: 0.7531 | test_loss: 0.7199 | test_acc: 0.6320


 66%|██████████████████████████████████████████████████████████████████████████████▌                                        | 33/50 [14:18<06:56, 24.50s/it]

Epoch: 33 | train_loss: 0.5398 | train_acc: 0.7597 | test_loss: 0.7027 | test_acc: 0.6536


 68%|████████████████████████████████████████████████████████████████████████████████▉                                      | 34/50 [14:42<06:31, 24.47s/it]

Epoch: 34 | train_loss: 0.5376 | train_acc: 0.7631 | test_loss: 0.6629 | test_acc: 0.6455


 70%|███████████████████████████████████████████████████████████████████████████████████▎                                   | 35/50 [15:07<06:06, 24.45s/it]

Epoch: 35 | train_loss: 0.5568 | train_acc: 0.7496 | test_loss: 0.6648 | test_acc: 0.6390


 72%|█████████████████████████████████████████████████████████████████████████████████████▋                                 | 36/50 [15:31<05:41, 24.43s/it]

Epoch: 36 | train_loss: 0.5432 | train_acc: 0.7580 | test_loss: 0.6451 | test_acc: 0.6713


 74%|████████████████████████████████████████████████████████████████████████████████████████                               | 37/50 [15:55<05:17, 24.41s/it]

Epoch: 37 | train_loss: 0.5442 | train_acc: 0.7647 | test_loss: 0.6872 | test_acc: 0.6498


 76%|██████████████████████████████████████████████████████████████████████████████████████████▍                            | 38/50 [16:20<04:52, 24.40s/it]

Epoch: 38 | train_loss: 0.5433 | train_acc: 0.7614 | test_loss: 0.6648 | test_acc: 0.6627


 78%|████████████████████████████████████████████████████████████████████████████████████████████▊                          | 39/50 [16:44<04:28, 24.40s/it]

Epoch: 39 | train_loss: 0.5116 | train_acc: 0.7678 | test_loss: 0.6529 | test_acc: 0.6719


 80%|███████████████████████████████████████████████████████████████████████████████████████████████▏                       | 40/50 [17:09<04:03, 24.40s/it]

Epoch: 40 | train_loss: 0.5328 | train_acc: 0.7744 | test_loss: 0.7163 | test_acc: 0.6536


 82%|█████████████████████████████████████████████████████████████████████████████████████████████████▌                     | 41/50 [17:33<03:39, 24.43s/it]

Epoch: 41 | train_loss: 0.5519 | train_acc: 0.7492 | test_loss: 0.6622 | test_acc: 0.6767
[INFO] Saving model to: TrackExp/BestWeight_Part_ICIP_final_denoise_Epoch_40_TrainLoss_0.5518940212933914_TestLoss_0.6621803785192555TrainAcc_0.7492087595907928_Testacc_0.6767241379310345.pth


 84%|███████████████████████████████████████████████████████████████████████████████████████████████████▉                   | 42/50 [17:58<03:15, 24.45s/it]

Epoch: 42 | train_loss: 0.5403 | train_acc: 0.7570 | test_loss: 0.6410 | test_acc: 0.6622


 86%|██████████████████████████████████████████████████████████████████████████████████████████████████████▎                | 43/50 [18:22<02:51, 24.44s/it]

Epoch: 43 | train_loss: 0.5426 | train_acc: 0.7598 | test_loss: 0.6461 | test_acc: 0.6670


 88%|████████████████████████████████████████████████████████████████████████████████████████████████████████▋              | 44/50 [18:48<02:29, 24.95s/it]

Epoch: 44 | train_loss: 0.5335 | train_acc: 0.7567 | test_loss: 0.6637 | test_acc: 0.6433


 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████            | 45/50 [19:13<02:04, 24.98s/it]

Epoch: 45 | train_loss: 0.5299 | train_acc: 0.7741 | test_loss: 0.7027 | test_acc: 0.6476


 92%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▍         | 46/50 [19:39<01:40, 25.13s/it]

Epoch: 46 | train_loss: 0.5436 | train_acc: 0.7603 | test_loss: 0.6987 | test_acc: 0.6412


 94%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▊       | 47/50 [20:05<01:16, 25.40s/it]

Epoch: 47 | train_loss: 0.5236 | train_acc: 0.7764 | test_loss: 0.7108 | test_acc: 0.6584


 96%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏    | 48/50 [20:30<00:50, 25.27s/it]

Epoch: 48 | train_loss: 0.5495 | train_acc: 0.7575 | test_loss: 0.6692 | test_acc: 0.6460


 98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌  | 49/50 [21:01<00:26, 26.96s/it]

Epoch: 49 | train_loss: 0.5427 | train_acc: 0.7584 | test_loss: 0.6918 | test_acc: 0.6336


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [21:37<00:00, 25.95s/it]

Epoch: 50 | train_loss: 0.5312 | train_acc: 0.7735 | test_loss: 0.7358 | test_acc: 0.6331





## Save model

In [None]:
# # Save the model
from going_modular.going_modular import utils

utils.save_model(model=AmirModel,
                 target_dir="./TrackExp",
                 model_name=name2 + ".pth")

## Load model weights

In [22]:
name2 = "TrackExp/BestWeight_Part_19_Epoch_2_TrainLoss_0.01639417985006382_TestLoss_0.027976555515579093TrainAcc_0.9941451149425288_Testacc_0.9951171875.pth"
AmirModel.load_state_dict(torch.load(name2))

<All keys matched successfully>