In [None]:
import numpy as np

# Quick sanity check: load one per-recording spectrogram file and show its array shape.
sample_path = '/kaggle/input/brain-eeg-spectrograms/EEG_Spectrograms/1007356722.npy'
x = np.load(sample_path)
print(x.shape)

In [1]:
import numpy as np 
import pandas as pd 
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,random_split
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision.datasets import ImageFolder
from PIL import Image
import os
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
from torch.amp import autocast, GradScaler
from tqdm import tqdm
from torchvision.transforms import v2

In [2]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

class EEGSpectrogramDataset(Dataset):
    """Serve EEG spectrograms as 3-channel image tensors with integer class labels.

    For every item the first four slices along the last axis of the stored
    spectrogram are stacked vertically into one single-channel image, that
    channel is copied three times, the result is bilinearly resized to
    ``image_size`` and finally passed through ``self.transform``.

    Args:
        eeg_ids: Indexable sequence of keys; it defines the dataset length and order.
        spectrogram_dict: Mapping from EEG id to an array that can be indexed as
            ``arr[..., i]`` for ``i`` in 0..3 (the sample printed in this notebook
            has shape (128, 256, 4)).
        label_dict: Mapping from EEG id to a label string.
        image_size: ``(height, width)`` of the returned image. Defaults to
            ``(224, 224)``, the value that was previously hard-coded.
    """

    # Label string -> class index. Built once instead of on every lookup.
    LABEL_TO_INDEX = {
        'Seizure': 0,
        'LPD': 1,
        'GPD': 2,
        'LRDA': 3,
        'GRDA': 4,
        'Other': 5
    }
    # Index returned for any label string that is not a key of LABEL_TO_INDEX.
    FALLBACK_INDEX = 5

    def __init__(self, eeg_ids, spectrogram_dict, label_dict, image_size=(224, 224)):
        self.eeg_ids = eeg_ids
        self.spectrogram_dict = spectrogram_dict
        self.label_dict = label_dict
        self.image_size = tuple(image_size)

        # The tensor handed to this pipeline has already been interpolated to
        # `image_size` in __getitem__; the Resize step is kept with the same target size.
        # NOTE(review): ImageNet mean/std are applied to the spectrogram values as stored;
        # confirm the stored values are on a comparable scale.
        self.transform = v2.Compose([
            v2.Resize(self.image_size),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of EEG ids served by this dataset."""
        return len(self.eeg_ids)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the EEG id at position ``idx``.

        Returns:
            image: float32 tensor of shape ``(3, *image_size)`` (before ``self.transform``
                is applied; the transform's output is returned as-is).
            label: 0-dim ``torch.long`` tensor holding the class index.
        """
        eeg_id = self.eeg_ids[idx]
        spectrogram = self.spectrogram_dict[eeg_id]
        label = self._label_to_index(self.label_dict[eeg_id])

        # Put the four slices on top of each other: (H, W, 4) -> (4*H, W).
        stacked = np.vstack([spectrogram[..., i] for i in range(4)])
        # Add a leading channel axis: (1, 4*H, W).
        stacked = torch.tensor(stacked, dtype=torch.float32).unsqueeze(0)

        # Copy the single channel three times so the image has 3 channels.
        rgb = stacked.repeat(3, 1, 1)

        # Resize with bilinear interpolation; interpolate() needs a batch axis.
        rgb = F.interpolate(rgb.unsqueeze(0),
                            size=self.image_size,
                            mode='bilinear',
                            align_corners=False).squeeze(0)

        # Apply transforms
        rgb = self.transform(rgb)

        return rgb, torch.tensor(label, dtype=torch.long)

    def _label_to_index(self, label_str):
        """Map a label string to its class index; unknown strings map to FALLBACK_INDEX."""
        return self.LABEL_TO_INDEX.get(label_str, self.FALLBACK_INDEX)

# Input files: one pickled object holding all spectrograms, and the competition label table.
spec_path = '/kaggle/input/brain-eeg-spectrograms/eeg_specs.npy'
labels_path = '/kaggle/input/hms-harmful-brain-activity-classification/train.csv'

# .item() unwraps the object stored in the .npy file; it is used as a dict keyed by EEG id below.
# NOTE(review): allow_pickle=True unpickles the file on load - only use it with trusted files.
spectrogram_dict = np.load(spec_path, allow_pickle=True).item()

train_df = pd.read_csv(labels_path)
# eeg_id -> expert_consensus. If an eeg_id occurs on several rows, the last row's label is kept.
label_dict = dict(zip(train_df['eeg_id'], train_df['expert_consensus']))

# Keep only ids that have both a spectrogram and a label (order follows spectrogram_dict).
common_eeg_ids = [key for key in spectrogram_dict if key in label_dict]

print(f"Total EEG IDs in spectrogram dict: {len(spectrogram_dict)}")
print(f"Total EEG IDs in labels: {len(label_dict)}")
print(f"Common EEG IDs with both: {len(common_eeg_ids)}")
print(f"Sample spectrogram shape: {spectrogram_dict[common_eeg_ids[0]].shape}")

Total EEG IDs in spectrogram dict: 17089
Total EEG IDs in labels: 17089
Common EEG IDs with both: 17089
Sample spectrogram shape: (128, 256, 4)


In [13]:
# 70/30 split of the EEG ids, stratified on the label string so both parts keep the class mix.
stratify_labels = [label_dict[eeg_id] for eeg_id in common_eeg_ids]
train_eeg_ids, test_eeg_ids = train_test_split(
    common_eeg_ids, test_size=0.3, random_state=42, stratify=stratify_labels
)

# Both datasets share the same spectrogram and label mappings; only the id lists differ.
train_dataset, test_dataset = (
    EEGSpectrogramDataset(ids, spectrogram_dict, label_dict)
    for ids in (train_eeg_ids, test_eeg_ids)
)

batch_size = 16
# Settings common to both loaders.
loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)

# Training batches are reshuffled every epoch and an incomplete final batch is dropped.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
# Evaluation keeps a fixed order and every sample.
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print("\nDataset successfully created:")
print(f"Total samples: {len(common_eeg_ids)}")
print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Batch size: {batch_size}")
print(f"Sample spectrogram shape: {train_dataset[0][0].shape}")


Dataset successfully created:
Total samples: 17089
Train samples: 11962
Test samples: 5127
Batch size: 16
Sample spectrogram shape: torch.Size([3, 224, 224])


In [14]:
# Directory that receives the checkpoints of the first training run.
checkpoint_dir = "/kaggle/working/saved_models3"
os.makedirs(checkpoint_dir, exist_ok=True)

# One output per entry of the dataset's label table (Seizure, LPD, GPD, LRDA, GRDA, Other).
num_classes = 6

# Swin-Tiny from timm, loaded with pretrained weights and a num_classes-way output.
backbone_name = "swin_tiny_patch4_window7_224"
model = timm.create_model(backbone_name, pretrained=True, num_classes=num_classes)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=2e-4, weight_decay=0.01)
# The scheduler is stepped once per epoch by the training loop.
# NOTE(review): T_max=10 is shorter than the 20-epoch budget used by the training cells;
# confirm the learning rate is meant to rise again after epoch 10.
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=10, eta_min=1e-7)

In [15]:
# Use the GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# With more than one GPU, wrap the model so each batch is split across them.
multi_gpu = torch.cuda.device_count() > 1
if multi_gpu:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)

model = model.to(device)

In [6]:
# Display the device selected above as this cell's output.
device

device(type='cuda')

In [16]:
class EarlyStopping:
    """Flag that training should stop once the validation loss stops improving.

    Call the instance once per epoch with the current validation loss. A loss
    strictly below the best one seen so far resets the counter; anything else
    increments it. When the counter reaches ``patience``, ``early_stop`` is set
    to True (it is never reset by this class).

    Args:
        patience: Number of consecutive non-improving calls that triggers the stop.
        verbose: When True, print the counter on every non-improving call.
    """

    def __init__(self, patience=3, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        # Improvement: remember the new best and start counting from zero again.
        if val_loss < self.best_loss:
            self.best_loss, self.counter = val_loss, 0
            return

        # No improvement (ties included): one more strike.
        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping: {self.counter}/{self.patience} without improvement.")
        if self.counter >= self.patience:
            self.early_stop = True

test size: 0.3, lr = 2e-4

In [17]:
# Training run with the model / optimizer / scheduler built in the setup cells (lr = 2e-4).
# Every epoch: train, evaluate on the held-out split, save a checkpoint, update the best
# checkpoint, and stop early after 3 epochs without a lower validation loss.

# ---------------- Run configuration ----------------
# Set this to True to use FP16 training
fp16 = True
# `device` is chosen by availability and may be the CPU, so mixed precision is only
# switched on when the run really is on a CUDA device.
use_amp = fp16 and device.type == "cuda"
scaler = GradScaler(enabled=use_amp)

num_epochs = 20
best_acc = 0.0
best_test_loss = float('inf')
use_accuracy_for_best = False  # Set to False to save best model by val loss

# All checkpoints of this run are written here.
save_dir = "/kaggle/working/saved_models3"
os.makedirs(save_dir, exist_ok=True)
best_model_path = f"{save_dir}/swin_best_model.pth"
# NOTE(review): if `model` is wrapped in nn.DataParallel, the saved state_dict keys carry a
# "module." prefix - confirm that whatever loads these checkpoints expects that.

early_stopper = EarlyStopping(patience=3, verbose=True)

for epoch in tqdm(range(num_epochs)):
    # ---------------- Training Phase ----------------
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        with autocast(device_type=device.type, enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # `loss` is the batch mean; weight it by the batch size so the epoch figure is a
        # per-sample mean.
        batch_count = labels.size(0)
        running_loss += loss.item() * batch_count
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_count

    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    train_acc = correct / total
    avg_train_loss = running_loss / total
    print(f"\nEpoch [{epoch+1}/{num_epochs}]")
    print(f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_acc:.4f}")

    # ---------------- Validation Phase ----------------
    model.eval()
    test_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            with autocast(device_type=device.type, enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)

            # The last evaluation batch can be smaller than the others; weighting by batch
            # size keeps it from being over-represented in the average.
            batch_count = labels.size(0)
            test_loss += loss.item() * batch_count
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_count

    test_acc = correct / total
    avg_test_loss = test_loss / total
    print(f"Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

    # ---------------- Save Current Epoch Model ----------------
    model_path = f"{save_dir}/swin_epoch_{epoch+1}.pth"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved: {model_path}")

    # ---------------- Save Best Model ----------------
    if use_accuracy_for_best:
        is_best = test_acc > best_acc
        if is_best:
            best_acc = test_acc
    else:
        is_best = avg_test_loss < best_test_loss
        if is_best:
            best_test_loss = avg_test_loss
    if is_best:
        torch.save(model.state_dict(), best_model_path)

    # ---------------- Early Stopping ----------------
    early_stopper(avg_test_loss)
    if early_stopper.early_stop:
        print("Early stopping triggered. Ending training.")
        break

  0%|          | 0/20 [00:00<?, ?it/s]


Epoch [1/20]
Train Loss: 1.3932, Train Accuracy: 0.4874
Test Loss: 1.1476, Test Accuracy: 0.5793
Model saved: /kaggle/working/saved_models3/swin_epoch_1.pth


  5%|▌         | 1/20 [02:02<38:42, 122.25s/it]


Epoch [2/20]
Train Loss: 1.1051, Train Accuracy: 0.5980
Test Loss: 1.0482, Test Accuracy: 0.6269
Model saved: /kaggle/working/saved_models3/swin_epoch_2.pth


 10%|█         | 2/20 [04:04<36:42, 122.34s/it]


Epoch [3/20]
Train Loss: 0.9679, Train Accuracy: 0.6511
Test Loss: 0.9040, Test Accuracy: 0.6696
Model saved: /kaggle/working/saved_models3/swin_epoch_3.pth


 15%|█▌        | 3/20 [06:06<34:39, 122.33s/it]


Epoch [4/20]
Train Loss: 0.8720, Train Accuracy: 0.6879
Test Loss: 0.8541, Test Accuracy: 0.6977
Model saved: /kaggle/working/saved_models3/swin_epoch_4.pth


 20%|██        | 4/20 [08:09<32:36, 122.31s/it]


Epoch [5/20]
Train Loss: 0.7936, Train Accuracy: 0.7195


 25%|██▌       | 5/20 [10:11<30:33, 122.23s/it]

Test Loss: 0.9100, Test Accuracy: 0.6710
Model saved: /kaggle/working/saved_models3/swin_epoch_5.pth
EarlyStopping: 1/3 without improvement.

Epoch [6/20]
Train Loss: 0.7012, Train Accuracy: 0.7523
Test Loss: 0.8029, Test Accuracy: 0.7100
Model saved: /kaggle/working/saved_models3/swin_epoch_6.pth


 30%|███       | 6/20 [12:13<28:32, 122.32s/it]


Epoch [7/20]
Train Loss: 0.6033, Train Accuracy: 0.7863


 35%|███▌      | 7/20 [14:15<26:28, 122.17s/it]

Test Loss: 0.8165, Test Accuracy: 0.7135
Model saved: /kaggle/working/saved_models3/swin_epoch_7.pth
EarlyStopping: 1/3 without improvement.

Epoch [8/20]
Train Loss: 0.4760, Train Accuracy: 0.8358


 40%|████      | 8/20 [16:17<24:26, 122.19s/it]

Test Loss: 0.9015, Test Accuracy: 0.7152
Model saved: /kaggle/working/saved_models3/swin_epoch_8.pth
EarlyStopping: 2/3 without improvement.

Epoch [9/20]
Train Loss: 0.3547, Train Accuracy: 0.8819


 40%|████      | 8/20 [18:20<27:30, 137.52s/it]

Test Loss: 0.9514, Test Accuracy: 0.7088
Model saved: /kaggle/working/saved_models3/swin_epoch_9.pth
EarlyStopping: 3/3 without improvement.
Early stopping triggered. Ending training.





lr = 3e-4

In [12]:
# Training run for the learning rate named in the heading above this cell (lr = 3e-4).
# Every epoch: train, evaluate on the held-out split, save a checkpoint, update the best
# checkpoint, and stop early after 3 epochs without a lower validation loss.

# ---------------- Fresh model / optimizer for this run ----------------
# Without this the cell would keep training whatever model, optimizer and scheduler the
# previous run left behind, and the learning rate in the heading would never be applied.
learning_rate = 3e-4
model = timm.create_model("swin_tiny_patch4_window7_224", pretrained=True, num_classes=num_classes)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
# NOTE(review): T_max=10 is shorter than num_epochs below; confirm that is intended.
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-7)

# ---------------- Run configuration ----------------
# Set this to True to use FP16 training
fp16 = True
# `device` is chosen by availability and may be the CPU, so mixed precision is only
# switched on when the run really is on a CUDA device.
use_amp = fp16 and device.type == "cuda"
scaler = GradScaler(enabled=use_amp)

num_epochs = 20
best_acc = 0.0
best_test_loss = float('inf')
use_accuracy_for_best = False  # Set to False to save best model by val loss

# All checkpoints of this run are written here. No other cell creates this directory,
# so it is created here; torch.save would otherwise fail on a fresh kernel.
save_dir = "/kaggle/working/saved_models2"
os.makedirs(save_dir, exist_ok=True)
best_model_path = f"{save_dir}/swin_best_model.pth"
# NOTE(review): if `model` is wrapped in nn.DataParallel, the saved state_dict keys carry a
# "module." prefix - confirm that whatever loads these checkpoints expects that.

early_stopper = EarlyStopping(patience=3, verbose=True)

for epoch in tqdm(range(num_epochs)):
    # ---------------- Training Phase ----------------
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        with autocast(device_type=device.type, enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # `loss` is the batch mean; weight it by the batch size so the epoch figure is a
        # per-sample mean.
        batch_count = labels.size(0)
        running_loss += loss.item() * batch_count
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_count

    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    train_acc = correct / total
    avg_train_loss = running_loss / total
    print(f"\nEpoch [{epoch+1}/{num_epochs}]")
    print(f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_acc:.4f}")

    # ---------------- Validation Phase ----------------
    model.eval()
    test_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            with autocast(device_type=device.type, enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)

            # The last evaluation batch can be smaller than the others; weighting by batch
            # size keeps it from being over-represented in the average.
            batch_count = labels.size(0)
            test_loss += loss.item() * batch_count
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_count

    test_acc = correct / total
    avg_test_loss = test_loss / total
    print(f"Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

    # ---------------- Save Current Epoch Model ----------------
    model_path = f"{save_dir}/swin_epoch_{epoch+1}.pth"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved: {model_path}")

    # ---------------- Save Best Model ----------------
    if use_accuracy_for_best:
        is_best = test_acc > best_acc
        if is_best:
            best_acc = test_acc
    else:
        is_best = avg_test_loss < best_test_loss
        if is_best:
            best_test_loss = avg_test_loss
    if is_best:
        torch.save(model.state_dict(), best_model_path)

    # ---------------- Early Stopping ----------------
    early_stopper(avg_test_loss)
    if early_stopper.early_stop:
        print("Early stopping triggered. Ending training.")
        break

  0%|          | 0/20 [00:00<?, ?it/s]


Epoch [1/20]
Train Loss: 1.2066, Train Accuracy: 0.5612
Test Loss: 1.0466, Test Accuracy: 0.6281


  5%|▌         | 1/20 [02:13<42:09, 133.11s/it]

Model saved: /kaggle/working/saved_models2/swin_epoch_1.pth

Epoch [2/20]
Train Loss: 0.9868, Train Accuracy: 0.6440
Test Loss: 0.9246, Test Accuracy: 0.6703
Model saved: /kaggle/working/saved_models2/swin_epoch_2.pth


 10%|█         | 2/20 [04:26<39:55, 133.10s/it]


Epoch [3/20]
Train Loss: 0.8768, Train Accuracy: 0.6876
Test Loss: 0.8728, Test Accuracy: 0.6808
Model saved: /kaggle/working/saved_models2/swin_epoch_3.pth


 15%|█▌        | 3/20 [06:39<37:43, 133.12s/it]


Epoch [4/20]
Train Loss: 0.7928, Train Accuracy: 0.7140


 20%|██        | 4/20 [08:52<35:28, 133.01s/it]

Test Loss: 0.8994, Test Accuracy: 0.6831
Model saved: /kaggle/working/saved_models2/swin_epoch_4.pth
EarlyStopping: 1/3 without improvement.

Epoch [5/20]
Train Loss: 0.7106, Train Accuracy: 0.7472
Test Loss: 0.8278, Test Accuracy: 0.7048
Model saved: /kaggle/working/saved_models2/swin_epoch_5.pth


 25%|██▌       | 5/20 [11:05<33:14, 132.96s/it]


Epoch [6/20]
Train Loss: 0.6107, Train Accuracy: 0.7845


 30%|███       | 6/20 [13:17<31:00, 132.90s/it]

Test Loss: 0.8679, Test Accuracy: 0.6951
Model saved: /kaggle/working/saved_models2/swin_epoch_6.pth
EarlyStopping: 1/3 without improvement.

Epoch [7/20]
Train Loss: 0.4827, Train Accuracy: 0.8320


 35%|███▌      | 7/20 [15:30<28:46, 132.83s/it]

Test Loss: 0.8893, Test Accuracy: 0.7168
Model saved: /kaggle/working/saved_models2/swin_epoch_7.pth
EarlyStopping: 2/3 without improvement.

Epoch [8/20]
Train Loss: 0.3451, Train Accuracy: 0.8828


 35%|███▌      | 7/20 [17:43<32:54, 151.90s/it]

Test Loss: 0.9999, Test Accuracy: 0.7101
Model saved: /kaggle/working/saved_models2/swin_epoch_8.pth
EarlyStopping: 3/3 without improvement.
Early stopping triggered. Ending training.





lr = 1e-4

In [8]:
# Training run for the learning rate named in the heading above this cell (lr = 1e-4).
# Every epoch: train, evaluate on the held-out split, save a checkpoint, update the best
# checkpoint, and stop early after 3 epochs without a lower validation loss.

# ---------------- Fresh model / optimizer for this run ----------------
# Without this the cell would keep training whatever model, optimizer and scheduler the
# previous run left behind, and the learning rate in the heading would never be applied.
learning_rate = 1e-4
model = timm.create_model("swin_tiny_patch4_window7_224", pretrained=True, num_classes=num_classes)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
# NOTE(review): T_max=10 is shorter than num_epochs below; confirm that is intended.
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-7)

# ---------------- Run configuration ----------------
# Set this to True to use FP16 training
fp16 = True
# `device` is chosen by availability and may be the CPU, so mixed precision is only
# switched on when the run really is on a CUDA device.
use_amp = fp16 and device.type == "cuda"
scaler = GradScaler(enabled=use_amp)

num_epochs = 20
best_acc = 0.0
best_test_loss = float('inf')
use_accuracy_for_best = False  # Set to False to save best model by val loss

# All checkpoints of this run are written here. No other cell creates this directory,
# so it is created here; torch.save would otherwise fail on a fresh kernel.
save_dir = "/kaggle/working/saved_models"
os.makedirs(save_dir, exist_ok=True)
best_model_path = f"{save_dir}/swin_best_model.pth"
# NOTE(review): if `model` is wrapped in nn.DataParallel, the saved state_dict keys carry a
# "module." prefix - confirm that whatever loads these checkpoints expects that.

early_stopper = EarlyStopping(patience=3, verbose=True)

for epoch in tqdm(range(num_epochs)):
    # ---------------- Training Phase ----------------
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        with autocast(device_type=device.type, enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # `loss` is the batch mean; weight it by the batch size so the epoch figure is a
        # per-sample mean.
        batch_count = labels.size(0)
        running_loss += loss.item() * batch_count
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_count

    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    train_acc = correct / total
    avg_train_loss = running_loss / total
    print(f"\nEpoch [{epoch+1}/{num_epochs}]")
    print(f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_acc:.4f}")

    # ---------------- Validation Phase ----------------
    model.eval()
    test_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            with autocast(device_type=device.type, enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)

            # The last evaluation batch can be smaller than the others; weighting by batch
            # size keeps it from being over-represented in the average.
            batch_count = labels.size(0)
            test_loss += loss.item() * batch_count
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_count

    test_acc = correct / total
    avg_test_loss = test_loss / total
    print(f"Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

    # ---------------- Save Current Epoch Model ----------------
    model_path = f"{save_dir}/swin_epoch_{epoch+1}.pth"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved: {model_path}")

    # ---------------- Save Best Model ----------------
    if use_accuracy_for_best:
        is_best = test_acc > best_acc
        if is_best:
            best_acc = test_acc
    else:
        is_best = avg_test_loss < best_test_loss
        if is_best:
            best_test_loss = avg_test_loss
    if is_best:
        torch.save(model.state_dict(), best_model_path)

    # ---------------- Early Stopping ----------------
    early_stopper(avg_test_loss)
    if early_stopper.early_stop:
        print("Early stopping triggered. Ending training.")
        break

  0%|          | 0/20 [00:00<?, ?it/s]


Epoch [1/20]
Train Loss: 1.2177, Train Accuracy: 0.5572
Test Loss: 1.0966, Test Accuracy: 0.6047
Model saved: /kaggle/working/saved_models/swin_epoch_1.pth


  5%|▌         | 1/20 [02:13<42:13, 133.32s/it]


Epoch [2/20]
Train Loss: 0.9386, Train Accuracy: 0.6601
Test Loss: 0.8937, Test Accuracy: 0.6849
Model saved: /kaggle/working/saved_models/swin_epoch_2.pth


 10%|█         | 2/20 [04:25<39:44, 132.45s/it]


Epoch [3/20]
Train Loss: 0.8163, Train Accuracy: 0.7073
Test Loss: 0.8253, Test Accuracy: 0.7115
Model saved: /kaggle/working/saved_models/swin_epoch_3.pth


 15%|█▌        | 3/20 [06:36<37:26, 132.17s/it]


Epoch [4/20]
Train Loss: 0.7325, Train Accuracy: 0.7361
Test Loss: 0.7914, Test Accuracy: 0.7177
Model saved: /kaggle/working/saved_models/swin_epoch_4.pth


 20%|██        | 4/20 [08:48<35:12, 132.04s/it]


Epoch [5/20]
Train Loss: 0.6458, Train Accuracy: 0.7682
Test Loss: 0.8573, Test Accuracy: 0.6984


 25%|██▌       | 5/20 [11:00<32:58, 131.88s/it]

Model saved: /kaggle/working/saved_models/swin_epoch_5.pth
EarlyStopping: 1/3 without improvement.

Epoch [6/20]
Train Loss: 0.5343, Train Accuracy: 0.8105
Test Loss: 0.8663, Test Accuracy: 0.7106


 30%|███       | 6/20 [13:11<30:44, 131.76s/it]

Model saved: /kaggle/working/saved_models/swin_epoch_6.pth
EarlyStopping: 2/3 without improvement.

Epoch [7/20]
Train Loss: 0.4065, Train Accuracy: 0.8566


 30%|███       | 6/20 [15:23<35:54, 153.91s/it]

Test Loss: 0.8804, Test Accuracy: 0.7232
Model saved: /kaggle/working/saved_models/swin_epoch_7.pth
EarlyStopping: 3/3 without improvement.
Early stopping triggered. Ending training.





In [8]:
# Further training run. Every epoch: train, evaluate on the held-out split, save a
# checkpoint, update the best checkpoint, and stop early after 3 epochs without a lower
# validation loss.
# NOTE(review): no learning rate is recorded for this run. The cell trains whatever
# `model`, `optimizer` and `scheduler` are currently in scope, and it writes to the same
# directory as the lr = 1e-4 run, overwriting that run's checkpoints. Confirm the intended
# settings and give the run its own directory if both sets of checkpoints are needed.

# ---------------- Run configuration ----------------
# Set this to True to use FP16 training
fp16 = True
# `device` is chosen by availability and may be the CPU, so mixed precision is only
# switched on when the run really is on a CUDA device.
use_amp = fp16 and device.type == "cuda"
scaler = GradScaler(enabled=use_amp)

num_epochs = 20
best_acc = 0.0
best_test_loss = float('inf')
use_accuracy_for_best = False  # Set to False to save best model by val loss

# All checkpoints of this run are written here. The directory is created here so that
# torch.save does not fail on a fresh kernel.
save_dir = "/kaggle/working/saved_models"
os.makedirs(save_dir, exist_ok=True)
best_model_path = f"{save_dir}/swin_best_model.pth"
# NOTE(review): if `model` is wrapped in nn.DataParallel, the saved state_dict keys carry a
# "module." prefix - confirm that whatever loads these checkpoints expects that.

early_stopper = EarlyStopping(patience=3, verbose=True)

for epoch in tqdm(range(num_epochs)):
    # ---------------- Training Phase ----------------
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        with autocast(device_type=device.type, enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # `loss` is the batch mean; weight it by the batch size so the epoch figure is a
        # per-sample mean.
        batch_count = labels.size(0)
        running_loss += loss.item() * batch_count
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_count

    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    train_acc = correct / total
    avg_train_loss = running_loss / total
    print(f"\nEpoch [{epoch+1}/{num_epochs}]")
    print(f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_acc:.4f}")

    # ---------------- Validation Phase ----------------
    model.eval()
    test_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            with autocast(device_type=device.type, enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)

            # The last evaluation batch can be smaller than the others; weighting by batch
            # size keeps it from being over-represented in the average.
            batch_count = labels.size(0)
            test_loss += loss.item() * batch_count
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_count

    test_acc = correct / total
    avg_test_loss = test_loss / total
    print(f"Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

    # ---------------- Save Current Epoch Model ----------------
    model_path = f"{save_dir}/swin_epoch_{epoch+1}.pth"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved: {model_path}")

    # ---------------- Save Best Model ----------------
    if use_accuracy_for_best:
        is_best = test_acc > best_acc
        if is_best:
            best_acc = test_acc
    else:
        is_best = avg_test_loss < best_test_loss
        if is_best:
            best_test_loss = avg_test_loss
    if is_best:
        torch.save(model.state_dict(), best_model_path)

    # ---------------- Early Stopping ----------------
    early_stopper(avg_test_loss)
    if early_stopper.early_stop:
        print("Early stopping triggered. Ending training.")
        break

  0%|          | 0/20 [00:00<?, ?it/s]


Epoch [1/20]
Train Loss: 1.1359, Train Accuracy: 0.5890
Test Loss: 0.9194, Test Accuracy: 0.6635


  5%|▌         | 1/20 [02:12<42:03, 132.80s/it]

Model saved: /kaggle/working/saved_models/swin_epoch_1.pth

Epoch [2/20]
Train Loss: 0.8681, Train Accuracy: 0.6912
Test Loss: 0.9150, Test Accuracy: 0.6849
Model saved: /kaggle/working/saved_models/swin_epoch_2.pth


 10%|█         | 2/20 [04:25<39:44, 132.48s/it]


Epoch [3/20]
Train Loss: 0.7645, Train Accuracy: 0.7301
Test Loss: 0.7847, Test Accuracy: 0.7159
Model saved: /kaggle/working/saved_models/swin_epoch_3.pth


 15%|█▌        | 3/20 [06:37<37:29, 132.30s/it]


Epoch [4/20]
Train Loss: 0.6679, Train Accuracy: 0.7629


 20%|██        | 4/20 [08:48<35:13, 132.12s/it]

Test Loss: 0.7990, Test Accuracy: 0.7180
Model saved: /kaggle/working/saved_models/swin_epoch_4.pth
EarlyStopping: 1/3 without improvement.

Epoch [5/20]
Train Loss: 0.5578, Train Accuracy: 0.8023


 25%|██▌       | 5/20 [11:01<33:01, 132.12s/it]

Test Loss: 0.8116, Test Accuracy: 0.7285
Model saved: /kaggle/working/saved_models/swin_epoch_5.pth
EarlyStopping: 2/3 without improvement.

Epoch [6/20]
Train Loss: 0.4231, Train Accuracy: 0.8529


 25%|██▌       | 5/20 [13:12<39:38, 158.60s/it]

Test Loss: 0.9393, Test Accuracy: 0.6975
Model saved: /kaggle/working/saved_models/swin_epoch_6.pth
EarlyStopping: 3/3 without improvement.
Early stopping triggered. Ending training.



