In [None]:
# # Install required libraries (run only once)
!pip install torch torchvision tqdm seaborn sklearn transformers wandb==0.19.6+computecanada argparse

In [2]:
import os
import glob
import time
import random
import re
from collections import defaultdict
from enum import Enum

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms, models
from torchvision.transforms import (
    RandomResizedCrop,
    RandomHorizontalFlip,
    RandomVerticalFlip,
    ColorJitter,
    RandomGrayscale,
    RandomApply,
    Compose,
    GaussianBlur,
    ToTensor,
    Normalize,
    CenterCrop,
    Resize
)

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from tqdm import tqdm
from PIL import Image, ImageEnhance, ImageOps
import seaborn as sns
import pandas as pd
import cv2

from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.preprocessing import normalize
from sklearn.metrics import roc_curve
from sklearn.decomposition import PCA

# Import wandb
import wandb


In [3]:
# Set device
torch.cuda.set_device("cuda:0")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'DEVICE: {DEVICE}')


DEVICE: cuda


In [4]:
run_name = 'simclr_rad-dino_pos-pairs_aug-pairs_100_epoch_m&ms_2'

In [5]:
class Config:
    def __init__(self):
        self.learning_rate = 0.001
        self.num_epochs = 100
        self.batch_size = 70  # Adjust as needed
        self.patience = 30
        self.dropout_p = 0.3
        self.image_shape = [256, 256]
        self.kernel_size = [21, 21]  # For the transforms, 10% of image size
        self.embedding_size = 128
        self.scheduler_step_size = 70
        self.scheduler_gamma = 0.1
        self.weight_decay = 1e-5
        self.max_norm = 1.0  # Gradient clipping
        self.temperature = 2.0
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.base_path = f"/home/saahmed/scratch/projects/Image-segmentation/Different-Data/rerieval_checkpoint/{run_name}"
        os.makedirs(self.base_path, exist_ok=True)
        self.best_model_path = os.path.join(self.base_path, "best_model.pth")
        self.last_model_path = os.path.join(self.base_path, "last_model.pth")
        self.learning_plot_path = os.path.join(self.base_path, "learning_curves.png")

config = Config()

In [6]:
os.environ["WANDB_API_KEY"] = "4f8dccbaced16f201316dd4113139739694dfd3b"

In [7]:
# Initialize wandb and log the configuration parameters.
wandb.init(
    project="simclr-training",
    name=run_name,
    config={
        "learning_rate": config.learning_rate,
        "num_epochs": config.num_epochs,
        "batch_size": config.batch_size,
        "dropout_p": config.dropout_p,
        "image_shape": config.image_shape,
        "embedding_size": config.embedding_size,
        "scheduler_step_size": config.scheduler_step_size,
        "scheduler_gamma": config.scheduler_gamma,
        "weight_decay": config.weight_decay,
        "max_norm": config.max_norm,
        "temperature": config.temperature,
    }
)

# Optionally, add the config to wandb for reference
wandb_config = wandb.config

[34m[1mwandb[0m: Currently logged in as: [33msalmagg[0m ([33mmy_research_projects[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [8]:
# ls /home/saahmed/scratch/projects/Image-segmentation/datasets/MAndMs/processed_data/Training

In [9]:
def convert_to_rgb(img):
    return img.convert("RGB")

class AugmentationSequenceType(Enum):
    temp = "temp"
    normal = "normal"

augmentation_sequence_map = {
    AugmentationSequenceType.temp.value: transforms.Compose([
        transforms.Resize((config.image_shape[0], config.image_shape[1])),
        transforms.Lambda(convert_to_rgb),
        transforms.RandomRotation(degrees=10),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.Lambda(lambda img: transforms.functional.adjust_contrast(img, contrast_factor=random.uniform(1, 1.3))),
        transforms.ToTensor(),
    ]),
    AugmentationSequenceType.normal.value: transforms.Compose([
        transforms.Resize((config.image_shape[0], config.image_shape[1])),
        transforms.Lambda(convert_to_rgb),
        transforms.ToTensor(),
    ]),
}

class ContrastiveLearningViewGenerator(object):
    def __init__(self, base_transform, normal_transform, n_views=2):
        self.base_transform = base_transform
        self.normal_transform = normal_transform
        self.n_views = n_views

    def __call__(self, x):
        if random.random() < 0.5:
            views = [self.base_transform(x) for _ in range(self.n_views)]
        else:
            views = [self.normal_transform(x), self.base_transform(x)]
        return views

class CombinedContrastiveDataset(Dataset):
    def __init__(self, list_images, positive_pairs, base_transform, normal_transform):
        self.list_images = list_images
        self.positive_pairs = positive_pairs
        self.all_images = self.positive_pairs + self.list_images
        self.base_transform = base_transform
        self.normal_transform = normal_transform
        self.view_generator = ContrastiveLearningViewGenerator(base_transform, normal_transform, n_views=2)
        
    def __len__(self):
        return len(self.list_images) + len(self.positive_pairs)
    
    def __getitem__(self, idx):
        if idx < len(self.positive_pairs):
            img_path1, img_path2 = self.all_images[idx]
            img1 = Image.open(img_path1)
            img2 = Image.open(img_path2)
            img1 = self.normal_transform(img1) 
            img2 = self.normal_transform(img2)
            return [img1, img2]
        else:
            img_path = self.all_images[idx]
            img = Image.open(img_path)
            views = self.view_generator(img)
            return views

# Prepare training image list
images_list_train = []
for i in ['es', 'ed']:
    path = f'/home/saahmed/scratch/projects/Image-segmentation/datasets/MAndMs/processed_data/Training/{i}/images/'
    images = [os.path.join(path, fname) for fname in os.listdir(path)]
    images_list_train += images

def train_val_test_split(list_filenames, train_size=0.7):
    list_filenames_train, list_filenames_val = train_test_split(
        list_filenames,
        train_size=train_size,
        shuffle=True,
        random_state=42)
    return list_filenames_train, list_filenames_val

list_images = images_list_train
list_images_train, list_images_val = train_val_test_split(list_images)

print("Total number of images: ", len(list_images))
print("Images in train split: ", len(list_images_train))
print("Images in validation split: ", len(list_images_val))

import os
import re
from collections import defaultdict

def create_positive_pairs(images):
    file_list = sorted(images)
    pattern = re.compile(r'([^/]+)/images/([A-Z0-9]+)_slice_(\d+)\.png')  # Matches ed/es and filename
    groups = defaultdict(list)

    for path in file_list:
        match = pattern.search(path)
        if match:
            phase = match.group(1)          # 'ed' or 'es'
            patient = match.group(2)        # e.g., 'C6J5P1'
            slice_num = int(match.group(3)) # e.g., 11
            key = (patient, phase)
            groups[key].append((slice_num, path))
        else:
            print(f"File {path} does not match the expected pattern.")
    
    positive_pairs = []
    for key, slices in groups.items():
        slices.sort(key=lambda x: x[0])  # Sort by slice number
        for i in range(len(slices) - 1):
            curr_slice, img1 = slices[i]
            next_slice, img2 = slices[i + 1]
            if curr_slice + 1 == next_slice:
                positive_pairs.append((img1, img2))

    return positive_pairs


pos_pairs_train = create_positive_pairs(list_images_train)
pos_pairs_val = create_positive_pairs(list_images_val)

output_shape = config.image_shape 
base_transforms = augmentation_sequence_map[AugmentationSequenceType.temp.value]
normal_transforms = augmentation_sequence_map[AugmentationSequenceType.normal.value]

image_ds_train = CombinedContrastiveDataset(
    list_images=list_images_train,
    positive_pairs=pos_pairs_train,
    base_transform=base_transforms,
    normal_transform=normal_transforms)

image_ds_val = CombinedContrastiveDataset(
    list_images=list_images_val,
    positive_pairs=pos_pairs_val,
    base_transform=base_transforms,
    normal_transform=normal_transforms)

BATCH_SIZE = config.batch_size

train_loader = DataLoader(
    image_ds_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)

val_loader = DataLoader(
    image_ds_val,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
)

print("Batches in TRAIN: ", len(train_loader))
print("Batches in VAL: ", len(val_loader))
# Note: If you have a test_loader, print it similarly.

Total number of images:  3286
Images in train split:  2300
Images in validation split:  986
Batches in TRAIN:  53
Batches in VAL:  17


In [10]:
print("samples in TRAIN: ", len(image_ds_train))
print("samples in VAL: ", len(image_ds_val))

samples in TRAIN:  3749
samples in VAL:  1255


In [11]:
from transformers import Dinov2Model

class SimCLR(nn.Module):
    def __init__(self, dropout_p=0.5, embedding_size=128, freeze=False, linear_eval=False):
        super().__init__()
        self.linear_eval = linear_eval
        self.dropout_p = dropout_p
        self.embedding_size = embedding_size

        # Load the DINOv2 model (you can change to any pretrained model)
        self.encoder = Dinov2Model.from_pretrained('microsoft/rad-dino')
        if freeze:
            for param in self.encoder.parameters():
                param.requires_grad = False

        self.projection = nn.Sequential(
            nn.Linear(768, 256),  # Assuming DINOv2 has an embedding dimension of 768
            nn.Dropout(p=self.dropout_p),
            nn.ReLU(),
            nn.Linear(256, embedding_size)
        )

    def forward(self, x):
        if not self.linear_eval:
            x = torch.cat(x, dim=0)  # Concatenate the two views
        outputs = self.encoder(x)
        encoding = outputs.last_hidden_state[:, 0]  # Extract the [CLS] token representation
        projection = self.projection(encoding)
        return projection

def save_model(model, save_path):
    model.encoder.save_pretrained(save_path)
    torch.save(model.projection.state_dict(), os.path.join(save_path, 'projection_head.pth'))

def load_model(model_class, load_path, device):
    encoder = Dinov2Model.from_pretrained(load_path)
    model = model_class()
    model.encoder = encoder
    projection_head_path = os.path.join(load_path, 'projection_head.pth')
    model.projection.load_state_dict(torch.load(projection_head_path, map_location=device))
    return model

def plot_training(train_loss_history, save_path, val_loss_history=None):
    plt.figure(figsize=(10, 5))
    plt.plot(train_loss_history, label='Train Loss')
    if val_loss_history is not None:
        plt.plot(val_loss_history, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(save_path)
    plt.show()

def contrastrive_loss(features, config):
    """NT-Xent (Normalized Temperature-Scaled Cross Entropy) Loss,
    aka. Contrastive Loss, used in the SimCLR paper.

    IMPORTANT NOTE: We don't really return the loss, but the logits
    and the (synthetic) labels to compute it with CrossEntropyLoss!

    The main idea behind SimCLR and contrastive learning is to learn
    representations that are close for positive pairs and far for negative pairs.
    In the case of SimCLR, a positive pair is two different augmentations
    of the same image, and a negative pair is two augmentations
    of two different images.

    How NT-Xent works:
    - Compute the cosine similarity between the representations
    of all pairs of images in the batch.
    - Apply a softmax to these similarities, but treat the similarity
    of each image with its positive pair as the correct class.
    This means that for each image, the goal is to make the
    softmax probability of its positive pair as high as possible,
    and the softmax probabilities of its negative pairs as low as possible.
    - Compute the cross entropy between these softmax probabilities
    and the true labels (which have a 1 for the positive pair
    and 0 for the negative pairs).
    - The temperature parameter scales the similarities before the softmax.
    A lower temperature makes the softmax output more peaky
    (i.e., the highest value will be much higher than the others,
    and the lower values will be closer to zero),
    while a higher temperature makes the softmax output more uniform.

    Args:
        projections: cat(z1, z2)
        z1: The projection of the first branch/view
        z2: The projeciton of the second branch/view

    Returns:
        the NTxent loss

    Notes on the shapes:
        inputs to model (views): [(B, C, W, H), (B, C, W, H)]
            B: batch size
            C: channels
            W: width
            H: height
            E: embedding size
        outputs from model (projections): [2*B, E]
        LABELS: [2*B, 2*B]
        features = outputs from model: [2*B, E]
        mask: [2*B, 2*B]
        similarity_matrix: [2*B, 2*B-1]
        positives: [2*B, 1]
        negatives: [2*B, 2*B-2]
        logits: [2*B, 2*B-1]
        labels: [2*B]
    """
    # FIXME: Refactor: take config out and pass necessary params, remove capital variables, etc.
    # FIXME: convert into class
    BATCH_SIZE = config.batch_size
    DEVICE = config.device
    TEMPERATURE = config.temperature

    LABELS = torch.cat([torch.arange(BATCH_SIZE) for i in range(2)], dim=0)
    LABELS = (LABELS.unsqueeze(0) == LABELS.unsqueeze(1)).float() # Creates a one-hot with broadcasting
    LABELS = LABELS.to(DEVICE) # 2*B, 2*B

    similarity_matrix = torch.matmul(features, features.T) # 2*B, 2*B
    # discard the main diagonal from both: labels and similarities matrix
    mask = torch.eye(LABELS.shape[0], dtype=torch.bool).to(DEVICE)
    # ~mask is the negative of the mask
    # the view is required to bring the matrix back to shape
    labels = LABELS[~mask].view(LABELS.shape[0], -1) # 2*B, 2*B-1
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1) # 2*B, 2*B-1

    # select and combine multiple positives
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1) # 2*B, 1

    # select only the negatives
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1) # 2*B, 2*B-2

    logits = torch.cat([positives, negatives], dim=1) # 2*B, 2*B-1
    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(DEVICE)

    logits = logits / TEMPERATURE

    return logits, labels

In [12]:
model = SimCLR(dropout_p=config.dropout_p, embedding_size=config.embedding_size).to(config.device)
criterion = nn.CrossEntropyLoss().to(config.device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
scheduler = StepLR(optimizer, step_size=config.scheduler_step_size, gamma=config.scheduler_gamma)

# Optionally, let wandb watch the model (logs gradients and parameters)
# wandb.watch(model, log="all")

In [13]:
wandb.watch(model, log="all")

In [15]:
def validate(model, val_loader, criterion, config):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for views in val_loader:
            projections = model([view.to(config.device) for view in views])
            logits, labels = contrastrive_loss(projections, config)
            loss = criterion(logits, labels)
            val_loss += loss.item()
    return val_loss / len(val_loader)

def train(model, train_loader, val_loader, criterion, optimizer, scheduler, config, output_freq=2, debug=False):
    model = model.to(config.device)
    train_loss_history = []
    val_loss_history = []
    best_val_loss = float('inf')
    no_improve_epochs = 0
    total_batches = len(train_loader)
    print_every = total_batches // output_freq

    for epoch in tqdm(range(config.num_epochs)):
        start_time = time.time()
        train_loss = 0.0
        model.train()

        for i, views in enumerate(train_loader):
            projections = model([view.to(config.device) for view in views])
            logits, labels = contrastrive_loss(projections, config)
            if debug and (torch.isnan(logits).any() or torch.isinf(logits).any()):
                print("[WARNING]: large logits")
                logits = logits.clamp(min=-10, max=10)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_norm)
            optimizer.step()
            train_loss += loss.item()

        scheduler.step()
        train_loss /= len(train_loader)
        train_loss_history.append(train_loss)

        val_loss = validate(model, val_loader, criterion, config)
        val_loss_history.append(val_loss)

        epoch_time = time.time() - start_time
        current_lr = scheduler.get_last_lr()[0]

        # Log metrics to wandb
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "learning_rate": current_lr,
            "epoch_time": epoch_time
        })

        print(f"Epoch: {epoch+1}, Loss: {train_loss}, Val Loss: {val_loss}, Time: {epoch_time:.2f}s, LR: {current_lr}")

        # Save the last model checkpoint locally and log it as an artifact if needed.
        save_model(model, config.last_model_path)
        # artifact = wandb.Artifact("last-model", type="model", metadata={"epoch": epoch+1})
        # artifact.add_dir(config.last_model_path)
        # wandb.log_artifact(artifact, aliases=["latest"])

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_model(model, config.best_model_path)

            # # Create a wandb artifact and log the best model checkpoint
            # artifact = wandb.Artifact("best-model", type="model", metadata={"epoch": epoch+1})
            # artifact.add_dir(config.best_model_path)
            # wandb.log_artifact(artifact, aliases=["latest"])
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= config.patience:
                print("Early stopping")
                break

    return train_loss_history, val_loss_history

In [None]:
train_loss_history, val_loss_history = train(model, train_loader, val_loader, criterion, optimizer, scheduler, config)

  0%|                                                                                                                     | 0/100 [00:00<?, ?it/s]

Epoch: 1, Loss: 4.641082448779412, Val Loss: 3.9483172332539276, Time: 182.58s, LR: 0.001


  1%|█                                                                                                         | 1/100 [03:03<5:03:09, 183.73s/it]

Epoch: 2, Loss: 3.837065300851498, Val Loss: 3.461095262976254, Time: 178.50s, LR: 0.001


  2%|██                                                                                                        | 2/100 [06:03<4:56:21, 181.44s/it]

Epoch: 3, Loss: 3.6371613493505515, Val Loss: 3.3486707911771885, Time: 177.33s, LR: 0.001


  3%|███▏                                                                                                      | 3/100 [09:02<4:51:16, 180.16s/it]

Epoch: 4, Loss: 3.209782595904368, Val Loss: 2.631897435468786, Time: 176.44s, LR: 0.001


  4%|████▏                                                                                                     | 4/100 [11:59<4:46:38, 179.15s/it]

Epoch: 5, Loss: 2.92814638929547, Val Loss: 2.7288812048294964, Time: 176.50s, LR: 0.001


  5%|█████▎                                                                                                    | 5/100 [14:56<4:42:30, 178.43s/it]

Epoch: 6, Loss: 2.8500386229101218, Val Loss: 2.3228602128870346, Time: 177.28s, LR: 0.001


  6%|██████▎                                                                                                   | 6/100 [17:55<4:39:37, 178.48s/it]

Epoch: 7, Loss: 2.7320603649571256, Val Loss: 2.2521372262169335, Time: 176.54s, LR: 0.001


  7%|███████▍                                                                                                  | 7/100 [20:53<4:36:21, 178.30s/it]

Epoch: 8, Loss: 2.673531302865946, Val Loss: 2.2015566966112923, Time: 176.86s, LR: 0.001


  8%|████████▍                                                                                                 | 8/100 [23:51<4:33:19, 178.26s/it]

Epoch: 9, Loss: 2.660335752199281, Val Loss: 2.2510033074547264, Time: 176.83s, LR: 0.001


  9%|█████████▌                                                                                                | 9/100 [26:49<4:30:00, 178.03s/it]

Epoch: 10, Loss: 2.5930257023505443, Val Loss: 2.0824207137612736, Time: 176.26s, LR: 0.001


 10%|██████████▌                                                                                              | 10/100 [29:46<4:26:49, 177.88s/it]

Epoch: 11, Loss: 2.4716402314743906, Val Loss: 1.8759993875727934, Time: 176.89s, LR: 0.001


 11%|███████████▌                                                                                             | 11/100 [32:44<4:23:59, 177.97s/it]

Epoch: 12, Loss: 2.3278168282418883, Val Loss: 1.8016362330492806, Time: 175.74s, LR: 0.001


 12%|████████████▌                                                                                            | 12/100 [35:41<4:20:36, 177.69s/it]

Epoch: 13, Loss: 2.2196289008518435, Val Loss: 1.664936717818765, Time: 176.88s, LR: 0.001


 13%|█████████████▋                                                                                           | 13/100 [38:40<4:17:52, 177.85s/it]

Epoch: 14, Loss: 2.167031609787131, Val Loss: 1.537646770477295, Time: 180.18s, LR: 0.001


 14%|██████████████▋                                                                                          | 14/100 [41:41<4:16:32, 178.98s/it]

Epoch: 15, Loss: 2.1963647289096184, Val Loss: 1.5174087426241707, Time: 199.26s, LR: 0.001


 15%|███████████████▊                                                                                         | 15/100 [45:02<4:22:49, 185.53s/it]

Epoch: 16, Loss: 2.164521005918395, Val Loss: 1.4792499892851885, Time: 210.42s, LR: 0.001


 16%|████████████████▊                                                                                        | 16/100 [48:34<4:30:52, 193.48s/it]

Epoch: 17, Loss: 2.063102416272433, Val Loss: 1.52752545300652, Time: 210.57s, LR: 0.001


 17%|█████████████████▊                                                                                       | 17/100 [52:05<4:35:03, 198.84s/it]

Epoch: 18, Loss: 2.0267876544088685, Val Loss: 1.3678675960091984, Time: 209.76s, LR: 0.001


 18%|██████████████████▉                                                                                      | 18/100 [55:36<4:36:51, 202.58s/it]

Epoch: 19, Loss: 1.9492189749231879, Val Loss: 1.344238982481115, Time: 211.35s, LR: 0.001


 19%|███████████████████▉                                                                                     | 19/100 [59:09<4:37:38, 205.66s/it]

Epoch: 20, Loss: 1.8313930709407014, Val Loss: 1.2518778548521154, Time: 210.51s, LR: 0.001


 20%|████████████████████▌                                                                                  | 20/100 [1:02:41<4:36:43, 207.55s/it]

Epoch: 21, Loss: 1.7716761512576409, Val Loss: 1.0826472394606645, Time: 209.36s, LR: 0.001


 21%|█████████████████████▋                                                                                 | 21/100 [1:06:12<4:34:33, 208.52s/it]

Epoch: 22, Loss: 1.7087831452207745, Val Loss: 1.150752505835365, Time: 213.81s, LR: 0.001


 22%|██████████████████████▋                                                                                | 22/100 [1:09:47<4:33:26, 210.34s/it]

Epoch: 23, Loss: 1.7086199162141331, Val Loss: 1.0544680392040926, Time: 213.04s, LR: 0.001


 23%|███████████████████████▋                                                                               | 23/100 [1:13:21<4:31:34, 211.62s/it]

Epoch: 24, Loss: 1.56961560699175, Val Loss: 0.9987157337805804, Time: 213.58s, LR: 0.001


 24%|████████████████████████▋                                                                              | 24/100 [1:16:56<4:29:22, 212.66s/it]

Epoch: 25, Loss: 1.6224152412054673, Val Loss: 0.9554168652085697, Time: 214.42s, LR: 0.001


 25%|█████████████████████████▊                                                                             | 25/100 [1:20:32<4:27:04, 213.66s/it]

Epoch: 26, Loss: 1.5343552467957982, Val Loss: 0.9551829415209153, Time: 208.21s, LR: 0.001


 26%|██████████████████████████▊                                                                            | 26/100 [1:24:02<4:22:02, 212.47s/it]

Epoch: 27, Loss: 1.4967596800822132, Val Loss: 0.8340931745136485, Time: 197.27s, LR: 0.001


 27%|███████████████████████████▊                                                                           | 27/100 [1:27:21<4:13:29, 208.35s/it]

Epoch: 28, Loss: 1.5127446988843523, Val Loss: 0.9317857588038725, Time: 199.19s, LR: 0.001


 28%|████████████████████████████▊                                                                          | 28/100 [1:30:41<4:06:59, 205.82s/it]

Epoch: 29, Loss: 1.471505291057083, Val Loss: 0.849110368420096, Time: 216.93s, LR: 0.001


 29%|█████████████████████████████▊                                                                         | 29/100 [1:34:18<4:07:48, 209.41s/it]

Epoch: 30, Loss: 1.3863979623002827, Val Loss: 0.8709590855766746, Time: 218.41s, LR: 0.001


 30%|██████████████████████████████▉                                                                        | 30/100 [1:37:58<4:07:42, 212.32s/it]

Epoch: 31, Loss: 1.3872895330752966, Val Loss: 0.8485395382432377, Time: 222.97s, LR: 0.001


 31%|███████████████████████████████▉                                                                       | 31/100 [1:41:41<4:08:07, 215.76s/it]

Epoch: 32, Loss: 1.3455763189297802, Val Loss: 0.8350731099353117, Time: 229.72s, LR: 0.001


 32%|████████████████████████████████▉                                                                      | 32/100 [1:45:32<4:09:33, 220.20s/it]

Epoch: 33, Loss: 1.314924418926239, Val Loss: 0.7771610828006968, Time: 231.55s, LR: 0.001


 33%|█████████████████████████████████▉                                                                     | 33/100 [1:49:25<4:10:16, 224.12s/it]

Epoch: 34, Loss: 1.2657198546067723, Val Loss: 0.763131278402665, Time: 249.09s, LR: 0.001


 34%|███████████████████████████████████                                                                    | 34/100 [1:53:36<4:15:16, 232.07s/it]

Epoch: 35, Loss: 1.3173803383449338, Val Loss: 0.7699547865811516, Time: 251.51s, LR: 0.001


 35%|████████████████████████████████████                                                                   | 35/100 [1:57:48<4:18:00, 238.17s/it]

Epoch: 36, Loss: 1.2951058239307043, Val Loss: 0.7500349598772386, Time: 251.08s, LR: 0.001


 36%|█████████████████████████████████████                                                                  | 36/100 [2:02:01<4:18:43, 242.55s/it]

Epoch: 37, Loss: 1.1874089027350803, Val Loss: 0.71996548947166, Time: 258.75s, LR: 0.001


 37%|██████████████████████████████████████                                                                 | 37/100 [2:06:21<4:20:19, 247.93s/it]

Epoch: 38, Loss: 1.244825595954679, Val Loss: 0.6884725374333999, Time: 253.09s, LR: 0.001


 38%|███████████████████████████████████████▏                                                               | 38/100 [2:10:36<4:18:21, 250.03s/it]

Epoch: 39, Loss: 1.1356240207294248, Val Loss: 0.6202783794964061, Time: 256.08s, LR: 0.001


 39%|████████████████████████████████████████▏                                                              | 39/100 [2:14:54<4:16:32, 252.34s/it]

Epoch: 40, Loss: 1.11373311618589, Val Loss: 0.6174658116172341, Time: 253.37s, LR: 0.001


 40%|█████████████████████████████████████████▏                                                             | 40/100 [2:19:09<4:13:07, 253.12s/it]

Epoch: 41, Loss: 1.0540108950632923, Val Loss: 0.6367689861970789, Time: 255.35s, LR: 0.001


 41%|██████████████████████████████████████████▏                                                            | 41/100 [2:23:25<4:09:47, 254.02s/it]

Epoch: 42, Loss: 1.094182958018105, Val Loss: 0.6534037975703969, Time: 257.23s, LR: 0.001


 42%|███████████████████████████████████████████▎                                                           | 42/100 [2:27:43<4:06:43, 255.23s/it]

Epoch: 43, Loss: 1.0115317702293396, Val Loss: 0.5699596562806297, Time: 265.09s, LR: 0.001


 43%|████████████████████████████████████████████▎                                                          | 43/100 [2:32:10<4:05:50, 258.77s/it]

Epoch: 44, Loss: 1.025298334517569, Val Loss: 0.6237384343848509, Time: 267.02s, LR: 0.001


 44%|█████████████████████████████████████████████▎                                                         | 44/100 [2:36:38<4:04:05, 261.53s/it]

Epoch: 45, Loss: 0.9937038646554047, Val Loss: 0.595291207818424, Time: 267.75s, LR: 0.001


 45%|██████████████████████████████████████████████▎                                                        | 45/100 [2:41:07<4:01:41, 263.67s/it]

Epoch: 46, Loss: 0.9550303378195133, Val Loss: 0.5529804562821108, Time: 264.97s, LR: 0.001


 46%|███████████████████████████████████████████████▍                                                       | 46/100 [2:45:34<3:58:07, 264.59s/it]