In [1]:

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import timm
import torchsummary
import imageio
import numpy as np
import matplotlib.pyplot as plt
import os
from torch.cuda.amp import GradScaler, autocast

from tqdm.notebook import tqdm

from sklearn.model_selection import train_test_split, KFold

import sys
sys.path.append('../DataLoader')

from dataloader import SunImageDataset

from torch.func import stack_module_state
from torch.func import vmap

from lightning.fabric import Fabric

In [2]:
torch.set_float32_matmul_precision('medium')
fabric = Fabric(accelerator='cuda', devices=1, precision="bf16-mixed")
fabric.launch()
print(fabric.device)

Using bfloat16 Automatic Mixed Precision (AMP)


cuda:0


In [3]:
# Hyper-parameters
input_size = 224*224
hidden_size = 166
num_epochs = 20
batch_size = 2
learning_rate = 0.001
dropout = 0.5
# dropout = 0.6990787087509548

In [4]:
dataset = SunImageDataset(csv_file="D:\\dataset.csv", offset=0, transform=transforms.ToTensor())
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# # Without Validation Set
# trainset, testset = torch.utils.data.Subset(dataset, range(train_size)), torch.utils.data.Subset(dataset, range(train_size, len(dataset)))
# trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True, num_workers=10)

# trainloader = fabric.setup_dataloaders(trainloader)

# With Validation Set
# Split dataset into training and test sets
train_indices, test_indices = train_test_split(range(len(dataset)), test_size=0.2, shuffle=False)

# Further split training set into training and validation sets
train_indices, val_indices = train_test_split(train_indices, test_size=0.25, shuffle=False)  # 0.25 x 0.8 = 0.2

trainset = torch.utils.data.Subset(dataset, train_indices)
valset = torch.utils.data.Subset(dataset, val_indices)
testset = torch.utils.data.Subset(dataset, test_indices)

trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True, num_workers=10)
valloader = torch.utils.data.DataLoader(dataset=valset, batch_size=batch_size, shuffle=False, num_workers=10)

trainloader = fabric.setup_dataloaders(trainloader)
valloader = fabric.setup_dataloaders(valloader)


# # Get a batch of training data
# dataiter = iter(trainloader)
# images, labels = next(dataiter)
# images = torch.stack(images)
# print(images.shape)
# print(labels.shape)

# print(images)
# print(labels)

In [5]:

class GmiSwinTransformer(nn.Module):
    def __init__(self, hidden_size: int):
        super(GmiSwinTransformer, self).__init__()
        
        # Batch normalization for 3 channels
        self.bn = nn.BatchNorm2d(3)
        
        # Initialize Swin Transformer
        self.pretrained_model = timm.create_model(
            'swin_base_patch4_window7_224',
            pretrained=True,
            num_classes=hidden_size
        )
        
        
        # Fully connected layers
        self.fc = nn.Sequential(
            nn.LeakyReLU(),
            nn.Linear(hidden_size*10, hidden_size),
            nn.Dropout(p=dropout),  # Added dropout probability
            nn.LeakyReLU(),
            nn.Linear(hidden_size, 1),
            nn.LeakyReLU()
        )
    
    def forward(self, images) -> torch.Tensor:
        """
        Batch should be in format:
        {
            'images': torch.FloatTensor((10, 1, 224, 224))
        }
        """
        
        batch_size = images.shape[0]
        
        images = images.reshape(-1, 1, 224, 224)
        images = torch.cat([images, images, images], dim=1)
        normalized_images = self.bn(images)
        features = self.pretrained_model(normalized_images)
        image_features = features.view(batch_size, -1)
        
        output = self.fc(image_features)
        return output

# Initialize model
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = GmiSwinTransformer(hidden_size=hidden_size).to(device)
model = GmiSwinTransformer(hidden_size=hidden_size)

# print(torchsummary.summary(model, (10, 1, 224, 224)))

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

model, optimizer = fabric.setup(model, optimizer)
model.train()

_FabricModule(
  (_forward_module): GmiSwinTransformer(
    (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pretrained_model): SwinTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
        (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (layers): Sequential(
        (0): SwinTransformerStage(
          (downsample): Identity()
          (blocks): Sequential(
            (0): SwinTransformerBlock(
              (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attn): WindowAttention(
                (qkv): Linear(in_features=128, out_features=384, bias=True)
                (attn_drop): Dropout(p=0.0, inplace=False)
                (proj): Linear(in_features=128, out_features=128, bias=True)
                (proj_drop): Dropout(p=0.0, inplace=False)
                (softmax): Softmax(dim=-1)
              )
              (drop_pa

In [6]:
torch.cuda.empty_cache()

In [None]:
# Training the model
n_total_steps = len(trainloader)
avg_train_loss_over_epochs = []
avg_val_loss_over_epochs = []

for epoch in range(num_epochs):
    train_losses = []
    val_losses = []
    
    # Training loop
    model.train()
    for i, (images, labels) in tqdm(enumerate(trainloader), desc="Training Progress", total=len(trainloader)):
        # Move images and labels to device
        images = torch.stack(images).float()
        images = images.permute(1, 0, 2, 3, 4)  # Change shape to [5, 10, 1, 224, 224]
        labels = labels.float()

        # Forward pass with autograd
        outputs = model(images).squeeze(1)
        loss = criterion(outputs, labels)
        tqdm.write(f"Epoch: {epoch+1}, Index: {i}, Loss: {loss.item():.4f}")

        # Backward pass and optimization
        optimizer.zero_grad()
        fabric.backward(loss)
        optimizer.step()
        # Store the loss
        train_losses.append(loss.item())
    
    # Store the average training loss for this epoch
    avg_train_loss_over_epochs.append(sum(train_losses) / len(train_losses))
    
    # Validation loop
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(valloader, desc="Validation Progress"):
            images = torch.stack(images).float()
            images = images.permute(1, 0, 2, 3, 4)
            labels = labels.float()
            outputs = model(images).squeeze(1)
            loss = criterion(outputs, labels)
            print("Validation Loss: ", loss.item())
            val_losses.append(loss.item())
    
    # Store the average validation loss for this epoch
    avg_val_loss_over_epochs.append(sum(val_losses) / len(val_losses))

# Plot loss over epochs
plt.figure()
plt.plot(range(1, num_epochs + 1), avg_train_loss_over_epochs, label='Average Training Loss', marker='o')
plt.plot(range(1, num_epochs + 1), avg_val_loss_over_epochs, label='Average Validation Loss', marker='o')
plt.xticks(range(1, num_epochs + 1))  # Ensure x-axis includes all epoch numbers
plt.xlabel('Epoch')
plt.ylabel('Average Loss')
plt.legend()
plt.title('Training and Validation Loss Over Epochs')
plt.show()

Training Progress:   0%|          | 0/307 [00:00<?, ?it/s]

Epoch: 1, Index: 0, Loss: 2.9206
Epoch: 1, Index: 1, Loss: 0.2089
Epoch: 1, Index: 2, Loss: 62.4478
Epoch: 1, Index: 3, Loss: 7.9942
Epoch: 1, Index: 4, Loss: 7.0289
Epoch: 1, Index: 5, Loss: 5.4620
Epoch: 1, Index: 6, Loss: 0.4657
Epoch: 1, Index: 7, Loss: 0.7289
Epoch: 1, Index: 8, Loss: 2.1132
Epoch: 1, Index: 9, Loss: 0.7197
Epoch: 1, Index: 10, Loss: 4.2164
Epoch: 1, Index: 11, Loss: 5.0883
Epoch: 1, Index: 12, Loss: 13.0442
Epoch: 1, Index: 13, Loss: 9.5281
Epoch: 1, Index: 14, Loss: 6.6037
Epoch: 1, Index: 15, Loss: 1.5012
Epoch: 1, Index: 16, Loss: 2.2566
Epoch: 1, Index: 17, Loss: 2.5071
Epoch: 1, Index: 18, Loss: 7.7257
Epoch: 1, Index: 19, Loss: 1.7571
Epoch: 1, Index: 20, Loss: 0.3626
Epoch: 1, Index: 21, Loss: 1.6751
Epoch: 1, Index: 22, Loss: 1.9915
Epoch: 1, Index: 23, Loss: 0.0313
Epoch: 1, Index: 24, Loss: 0.1284
Epoch: 1, Index: 25, Loss: 0.7214
Epoch: 1, Index: 26, Loss: 3.5963
Epoch: 1, Index: 27, Loss: 0.7338
Epoch: 1, Index: 28, Loss: 1.7465
Epoch: 1, Index: 29, L

Validation Progress:   0%|          | 0/103 [00:00<?, ?it/s]

Training Progress:   0%|          | 0/307 [00:00<?, ?it/s]

Epoch: 2, Index: 0, Loss: 2.3417
Epoch: 2, Index: 1, Loss: 1.6401
Epoch: 2, Index: 2, Loss: 2.9971
Epoch: 2, Index: 3, Loss: 0.2854
Epoch: 2, Index: 4, Loss: 2.7117
Epoch: 2, Index: 5, Loss: 0.2868
Epoch: 2, Index: 6, Loss: 0.4082
Epoch: 2, Index: 7, Loss: 4.3517
Epoch: 2, Index: 8, Loss: 1.2372
Epoch: 2, Index: 9, Loss: 0.1097
Epoch: 2, Index: 10, Loss: 1.5996
Epoch: 2, Index: 11, Loss: 1.7532
Epoch: 2, Index: 12, Loss: 2.9294
Epoch: 2, Index: 13, Loss: 1.5399
Epoch: 2, Index: 14, Loss: 3.8688
Epoch: 2, Index: 15, Loss: 2.6604
Epoch: 2, Index: 16, Loss: 0.1204
Epoch: 2, Index: 17, Loss: 0.2820
Epoch: 2, Index: 18, Loss: 1.5088
Epoch: 2, Index: 19, Loss: 3.9357
Epoch: 2, Index: 20, Loss: 1.5324
Epoch: 2, Index: 21, Loss: 1.8409
Epoch: 2, Index: 22, Loss: 3.8735
Epoch: 2, Index: 23, Loss: 2.0441
Epoch: 2, Index: 24, Loss: 3.2741
Epoch: 2, Index: 25, Loss: 2.0183
Epoch: 2, Index: 26, Loss: 5.3922
Epoch: 2, Index: 27, Loss: 4.5136
Epoch: 2, Index: 28, Loss: 0.4267
Epoch: 2, Index: 29, Los

Validation Progress:   0%|          | 0/103 [00:00<?, ?it/s]

Training Progress:   0%|          | 0/307 [00:00<?, ?it/s]

Epoch: 3, Index: 0, Loss: 5.1784
Epoch: 3, Index: 1, Loss: 0.2419
Epoch: 3, Index: 2, Loss: 0.0980
Epoch: 3, Index: 3, Loss: 0.7094
Epoch: 3, Index: 4, Loss: 0.0331
Epoch: 3, Index: 5, Loss: 2.3116
Epoch: 3, Index: 6, Loss: 0.3737
Epoch: 3, Index: 7, Loss: 1.6654
Epoch: 3, Index: 8, Loss: 1.5970
Epoch: 3, Index: 9, Loss: 4.6975
Epoch: 3, Index: 10, Loss: 0.8623
Epoch: 3, Index: 11, Loss: 0.1042
Epoch: 3, Index: 12, Loss: 3.1381
Epoch: 3, Index: 13, Loss: 0.6921
Epoch: 3, Index: 14, Loss: 0.7000
Epoch: 3, Index: 15, Loss: 0.7720
Epoch: 3, Index: 16, Loss: 3.5910
Epoch: 3, Index: 17, Loss: 2.0761
Epoch: 3, Index: 18, Loss: 2.0630
Epoch: 3, Index: 19, Loss: 0.2008
Epoch: 3, Index: 20, Loss: 3.6258
Epoch: 3, Index: 21, Loss: 8.7638
Epoch: 3, Index: 22, Loss: 2.7267
Epoch: 3, Index: 23, Loss: 0.3770
Epoch: 3, Index: 24, Loss: 3.1955
Epoch: 3, Index: 25, Loss: 2.5672
Epoch: 3, Index: 26, Loss: 0.0196
Epoch: 3, Index: 27, Loss: 2.4736
Epoch: 3, Index: 28, Loss: 5.5082
Epoch: 3, Index: 29, Los

Validation Progress:   0%|          | 0/103 [00:00<?, ?it/s]

Training Progress:   0%|          | 0/307 [00:00<?, ?it/s]

Epoch: 4, Index: 0, Loss: 1.6948
Epoch: 4, Index: 1, Loss: 3.7188
Epoch: 4, Index: 2, Loss: 0.0054
Epoch: 4, Index: 3, Loss: 0.0287
Epoch: 4, Index: 4, Loss: 1.6843
Epoch: 4, Index: 5, Loss: 4.7751
Epoch: 4, Index: 6, Loss: 6.6367
Epoch: 4, Index: 7, Loss: 2.6002
Epoch: 4, Index: 8, Loss: 2.0439
Epoch: 4, Index: 9, Loss: 3.2442
Epoch: 4, Index: 10, Loss: 1.2725
Epoch: 4, Index: 11, Loss: 3.6556
Epoch: 4, Index: 12, Loss: 8.8363
Epoch: 4, Index: 13, Loss: 0.4889
Epoch: 4, Index: 14, Loss: 0.9268
Epoch: 4, Index: 15, Loss: 8.2804
Epoch: 4, Index: 16, Loss: 0.4838
Epoch: 4, Index: 17, Loss: 0.6868
Epoch: 4, Index: 18, Loss: 5.1850
Epoch: 4, Index: 19, Loss: 0.0539
Epoch: 4, Index: 20, Loss: 4.1885
Epoch: 4, Index: 21, Loss: 6.2979
Epoch: 4, Index: 22, Loss: 1.1214
Epoch: 4, Index: 23, Loss: 2.3709
Epoch: 4, Index: 24, Loss: 5.1689
Epoch: 4, Index: 25, Loss: 0.6014
Epoch: 4, Index: 26, Loss: 0.0369
Epoch: 4, Index: 27, Loss: 0.8073
Epoch: 4, Index: 28, Loss: 3.3417
Epoch: 4, Index: 29, Los

Validation Progress:   0%|          | 0/103 [00:00<?, ?it/s]

Training Progress:   0%|          | 0/307 [00:00<?, ?it/s]

Epoch: 5, Index: 0, Loss: 1.6971
Epoch: 5, Index: 1, Loss: 1.8371
Epoch: 5, Index: 2, Loss: 6.9075
Epoch: 5, Index: 3, Loss: 6.5866
Epoch: 5, Index: 4, Loss: 0.1578
Epoch: 5, Index: 5, Loss: 0.7530
Epoch: 5, Index: 6, Loss: 2.7142
Epoch: 5, Index: 7, Loss: 0.4849
Epoch: 5, Index: 8, Loss: 0.3609
Epoch: 5, Index: 9, Loss: 2.6755
Epoch: 5, Index: 10, Loss: 1.6379
Epoch: 5, Index: 11, Loss: 2.9836
Epoch: 5, Index: 12, Loss: 3.4302
Epoch: 5, Index: 13, Loss: 0.1156
Epoch: 5, Index: 14, Loss: 3.3738
Epoch: 5, Index: 15, Loss: 0.3347
Epoch: 5, Index: 16, Loss: 0.4339
Epoch: 5, Index: 17, Loss: 1.1016
Epoch: 5, Index: 18, Loss: 2.3961
Epoch: 5, Index: 19, Loss: 0.0066
Epoch: 5, Index: 20, Loss: 6.9380
Epoch: 5, Index: 21, Loss: 0.6217
Epoch: 5, Index: 22, Loss: 2.0806
Epoch: 5, Index: 23, Loss: 1.0089
Epoch: 5, Index: 24, Loss: 4.0221
Epoch: 5, Index: 25, Loss: 0.8730
Epoch: 5, Index: 26, Loss: 2.3777
Epoch: 5, Index: 27, Loss: 1.9983
Epoch: 5, Index: 28, Loss: 5.3189
Epoch: 5, Index: 29, Los

Validation Progress:   0%|          | 0/103 [00:00<?, ?it/s]

Training Progress:   0%|          | 0/307 [00:00<?, ?it/s]

Epoch: 6, Index: 0, Loss: 11.6248
Epoch: 6, Index: 1, Loss: 0.5223
Epoch: 6, Index: 2, Loss: 0.1906
Epoch: 6, Index: 3, Loss: 5.1794
Epoch: 6, Index: 4, Loss: 0.9500
Epoch: 6, Index: 5, Loss: 1.1112
Epoch: 6, Index: 6, Loss: 8.2098
Epoch: 6, Index: 7, Loss: 0.8993
Epoch: 6, Index: 8, Loss: 2.9895
Epoch: 6, Index: 9, Loss: 4.3113
Epoch: 6, Index: 10, Loss: 0.2907
Epoch: 6, Index: 11, Loss: 1.0113
Epoch: 6, Index: 12, Loss: 0.3373
Epoch: 6, Index: 13, Loss: 4.8199
Epoch: 6, Index: 14, Loss: 0.6498
Epoch: 6, Index: 15, Loss: 0.0421
Epoch: 6, Index: 16, Loss: 1.6266
Epoch: 6, Index: 17, Loss: 0.0629
Epoch: 6, Index: 18, Loss: 3.2322
Epoch: 6, Index: 19, Loss: 9.4700
Epoch: 6, Index: 20, Loss: 1.4298
Epoch: 6, Index: 21, Loss: 1.5257
Epoch: 6, Index: 22, Loss: 0.6000
Epoch: 6, Index: 23, Loss: 5.3228
Epoch: 6, Index: 24, Loss: 0.3837
Epoch: 6, Index: 25, Loss: 2.0518
Epoch: 6, Index: 26, Loss: 0.9053
Epoch: 6, Index: 27, Loss: 3.3880
Epoch: 6, Index: 28, Loss: 2.7427
Epoch: 6, Index: 29, Lo

Validation Progress:   0%|          | 0/103 [00:00<?, ?it/s]

Training Progress:   0%|          | 0/307 [00:00<?, ?it/s]

Epoch: 7, Index: 0, Loss: 0.1933
Epoch: 7, Index: 1, Loss: 0.0273
Epoch: 7, Index: 2, Loss: 0.6413
Epoch: 7, Index: 3, Loss: 1.5105
Epoch: 7, Index: 4, Loss: 0.0505
Epoch: 7, Index: 5, Loss: 3.5033
Epoch: 7, Index: 6, Loss: 1.2255
Epoch: 7, Index: 7, Loss: 2.9277
Epoch: 7, Index: 8, Loss: 8.9506
Epoch: 7, Index: 9, Loss: 1.4321
Epoch: 7, Index: 10, Loss: 3.0192
Epoch: 7, Index: 11, Loss: 1.0162
Epoch: 7, Index: 12, Loss: 0.6929
Epoch: 7, Index: 13, Loss: 4.9030
Epoch: 7, Index: 14, Loss: 1.0826
Epoch: 7, Index: 15, Loss: 1.6206
Epoch: 7, Index: 16, Loss: 1.3593
Epoch: 7, Index: 17, Loss: 0.2608
Epoch: 7, Index: 18, Loss: 1.9328
Epoch: 7, Index: 19, Loss: 1.3511
Epoch: 7, Index: 20, Loss: 0.4600
Epoch: 7, Index: 21, Loss: 0.3714
Epoch: 7, Index: 22, Loss: 8.7972
Epoch: 7, Index: 23, Loss: 2.2158
Epoch: 7, Index: 24, Loss: 2.2256
Epoch: 7, Index: 25, Loss: 2.7156
Epoch: 7, Index: 26, Loss: 4.7522
Epoch: 7, Index: 27, Loss: 4.3483
Epoch: 7, Index: 28, Loss: 0.1716
Epoch: 7, Index: 29, Los

Validation Progress:   0%|          | 0/103 [00:00<?, ?it/s]

Training Progress:   0%|          | 0/307 [00:00<?, ?it/s]

Epoch: 8, Index: 0, Loss: 3.5552
Epoch: 8, Index: 1, Loss: 0.0135
Epoch: 8, Index: 2, Loss: 3.9035
Epoch: 8, Index: 3, Loss: 0.9293
Epoch: 8, Index: 4, Loss: 0.2636
Epoch: 8, Index: 5, Loss: 1.1982
Epoch: 8, Index: 6, Loss: 0.0637
Epoch: 8, Index: 7, Loss: 2.2893
Epoch: 8, Index: 8, Loss: 0.1531
Epoch: 8, Index: 9, Loss: 1.0965
Epoch: 8, Index: 10, Loss: 1.6870
Epoch: 8, Index: 11, Loss: 1.7338
Epoch: 8, Index: 12, Loss: 0.5416
Epoch: 8, Index: 13, Loss: 6.2569
Epoch: 8, Index: 14, Loss: 3.5696
Epoch: 8, Index: 15, Loss: 3.7975
Epoch: 8, Index: 16, Loss: 2.2557
Epoch: 8, Index: 17, Loss: 1.8187
Epoch: 8, Index: 18, Loss: 2.0146
Epoch: 8, Index: 19, Loss: 2.7413
Epoch: 8, Index: 20, Loss: 0.4814
Epoch: 8, Index: 21, Loss: 4.9832
Epoch: 8, Index: 22, Loss: 1.2779
Epoch: 8, Index: 23, Loss: 12.2485
Epoch: 8, Index: 24, Loss: 0.5613
Epoch: 8, Index: 25, Loss: 5.3556
Epoch: 8, Index: 26, Loss: 0.1004
Epoch: 8, Index: 27, Loss: 7.3297
Epoch: 8, Index: 28, Loss: 1.1380
Epoch: 8, Index: 29, Lo

Validation Progress:   0%|          | 0/103 [00:00<?, ?it/s]

Training Progress:   0%|          | 0/307 [00:00<?, ?it/s]

Epoch: 9, Index: 0, Loss: 1.4561
Epoch: 9, Index: 1, Loss: 1.0739
Epoch: 9, Index: 2, Loss: 3.3407
Epoch: 9, Index: 3, Loss: 0.3828
Epoch: 9, Index: 4, Loss: 5.9343
Epoch: 9, Index: 5, Loss: 0.6112
Epoch: 9, Index: 6, Loss: 4.1691
Epoch: 9, Index: 7, Loss: 0.1783
Epoch: 9, Index: 8, Loss: 0.0359
Epoch: 9, Index: 9, Loss: 1.1282
Epoch: 9, Index: 10, Loss: 0.6554
Epoch: 9, Index: 11, Loss: 4.1333
Epoch: 9, Index: 12, Loss: 0.1173
Epoch: 9, Index: 13, Loss: 0.1071
Epoch: 9, Index: 14, Loss: 0.3042
Epoch: 9, Index: 15, Loss: 0.2783
Epoch: 9, Index: 16, Loss: 5.0992
Epoch: 9, Index: 17, Loss: 1.4588
Epoch: 9, Index: 18, Loss: 0.1195
Epoch: 9, Index: 19, Loss: 0.7658
Epoch: 9, Index: 20, Loss: 1.5260
Epoch: 9, Index: 21, Loss: 1.1013
Epoch: 9, Index: 22, Loss: 0.7585
Epoch: 9, Index: 23, Loss: 1.7943
Epoch: 9, Index: 24, Loss: 0.6666
Epoch: 9, Index: 25, Loss: 3.7988
Epoch: 9, Index: 26, Loss: 0.2462
Epoch: 9, Index: 27, Loss: 3.9702
Epoch: 9, Index: 28, Loss: 1.5302
Epoch: 9, Index: 29, Los

Validation Progress:   0%|          | 0/103 [00:00<?, ?it/s]

Training Progress:   0%|          | 0/307 [00:00<?, ?it/s]

Epoch: 10, Index: 0, Loss: 0.2700
Epoch: 10, Index: 1, Loss: 3.7155
Epoch: 10, Index: 2, Loss: 0.3566
Epoch: 10, Index: 3, Loss: 7.6319
Epoch: 10, Index: 4, Loss: 2.6371
Epoch: 10, Index: 5, Loss: 16.5236
Epoch: 10, Index: 6, Loss: 3.2116
Epoch: 10, Index: 7, Loss: 0.2724
Epoch: 10, Index: 8, Loss: 0.2444
Epoch: 10, Index: 9, Loss: 0.2560
Epoch: 10, Index: 10, Loss: 2.0009
Epoch: 10, Index: 11, Loss: 0.2403
Epoch: 10, Index: 12, Loss: 3.2734
Epoch: 10, Index: 13, Loss: 1.7839
Epoch: 10, Index: 14, Loss: 0.1503
Epoch: 10, Index: 15, Loss: 3.1935
Epoch: 10, Index: 16, Loss: 2.7153
Epoch: 10, Index: 17, Loss: 3.8039
Epoch: 10, Index: 18, Loss: 0.3594
Epoch: 10, Index: 19, Loss: 1.2317
Epoch: 10, Index: 20, Loss: 1.3043
Epoch: 10, Index: 21, Loss: 2.9155
Epoch: 10, Index: 22, Loss: 4.2536
Epoch: 10, Index: 23, Loss: 0.4853
Epoch: 10, Index: 24, Loss: 1.1201
Epoch: 10, Index: 25, Loss: 0.5473
Epoch: 10, Index: 26, Loss: 0.8124
Epoch: 10, Index: 27, Loss: 0.7939
Epoch: 10, Index: 28, Loss: 2

Validation Progress:   0%|          | 0/103 [00:00<?, ?it/s]

Training Progress:   0%|          | 0/307 [00:00<?, ?it/s]

Epoch: 11, Index: 0, Loss: 1.6976
Epoch: 11, Index: 1, Loss: 0.3148
Epoch: 11, Index: 2, Loss: 1.9809
Epoch: 11, Index: 3, Loss: 1.2043
Epoch: 11, Index: 4, Loss: 3.5435
Epoch: 11, Index: 5, Loss: 8.4871
Epoch: 11, Index: 6, Loss: 0.3380
Epoch: 11, Index: 7, Loss: 1.0701
Epoch: 11, Index: 8, Loss: 2.8066
Epoch: 11, Index: 9, Loss: 1.8173
Epoch: 11, Index: 10, Loss: 0.7958
Epoch: 11, Index: 11, Loss: 0.3423
Epoch: 11, Index: 12, Loss: 0.9439
Epoch: 11, Index: 13, Loss: 0.0762
Epoch: 11, Index: 14, Loss: 0.4518
Epoch: 11, Index: 15, Loss: 1.6553
Epoch: 11, Index: 16, Loss: 1.3376
Epoch: 11, Index: 17, Loss: 1.7786
Epoch: 11, Index: 18, Loss: 0.7237
Epoch: 11, Index: 19, Loss: 0.2008
Epoch: 11, Index: 20, Loss: 0.3137
Epoch: 11, Index: 21, Loss: 0.2318
Epoch: 11, Index: 22, Loss: 1.2239
Epoch: 11, Index: 23, Loss: 0.4982
Epoch: 11, Index: 24, Loss: 3.3868
Epoch: 11, Index: 25, Loss: 1.5991
Epoch: 11, Index: 26, Loss: 0.4283
Epoch: 11, Index: 27, Loss: 0.7744
Epoch: 11, Index: 28, Loss: 4.

In [None]:
# Test the model
model.eval()
testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size, shuffle=False, num_workers=10)
testloader = fabric.setup_dataloaders(testloader)
test_losses = []

with torch.no_grad():
    for images, labels in tqdm(testloader, desc="Testing Progress"):
        images = torch.stack(images).float()
        images = images.permute(1, 0, 2, 3, 4)
        labels = labels.float()
        outputs = model(images).squeeze(1)
        loss = criterion(outputs, labels)
        test_losses.append(loss.item())
        print("Test loss", test_losses)

avg_test_loss = sum(test_losses) / len(test_losses)
print(f'Average test loss: {avg_test_loss:.4f}')

Testing Progress:   0%|          | 0/103 [00:00<?, ?it/s]