# JigSaw pretext task
Following [the original paper](https://arxiv.org/pdf/1603.09246) <br>
The [FAIR paper](https://arxiv.org/pdf/1905.01235) (page 12) is also useful for implementation details.<br>
Adapted to use ResNet18 instead of CFN.
TODO
- Organize the Jigsaw utilities into separate .py modules
- Create a runner script that can use up to 4 GPUs for faster training
- Add ViT
- Switch the training dataset to one with a native resolution of ~255x255 to avoid upscaling the images
- Add the evaluations: take the (pretrained) ResNet backbone and plug it into another module with a clean classification head
    - linear probing
    - full fine-tuning

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import transforms, models
from PIL import Image
import numpy as np
import os
import random
from tqdm import tqdm
from datasets import load_dataset
import datasets
from pathlib import Path

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Set random seed for reproducibility -> maybe use pytorch lightning for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)


## Load Tiny-ImageNet / ImageNet-1k from Hugging Face
The goal is to train and evaluate on ImageNet-1k, but first the Hugging Face datasets download path has to be changed (see https://discuss.huggingface.co/t/specifying-download-directory-for-custom-dataset-loading-script/11150).

In [4]:
download_path = '/l/users/emilio.villa/huggingface/datasets'
datasets.config.DOWNLOADED_DATASETS_PATH = Path(download_path)

In [5]:
# dataset = load_dataset("zh-plus/tiny-imagenet", cache_dir = download_path)
## It can also be downloaded from http://cs231n.stanford.edu/tiny-imagenet-200.zip, but the HF version is easier to use
# dataset

In [16]:
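## ImageNet-1k (gated dataset: requires a Hugging Face access token)
access_token = os.environ.get("HF_TOKEN")  # read the token from the environment instead of hard-coding it
dataset = load_dataset('ILSVRC/imagenet-1k', token=access_token, cache_dir=download_path)  # , num_proc=8)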
dataset


DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 1281167
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 50000
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 100000
    })
})


## Prepare dataset

In [8]:
def generate_permutations(n_permutations, n_tiles):
    """
    Generates a list of permutations, these will essentially be the 'gold truth' labels
    """
    permutations = []
    seen = set()
    while len(permutations) < n_permutations:
        perm = tuple(np.random.permutation(n_tiles))
        if perm not in seen:
            permutations.append(perm)
            seen.add(perm)
    return permutations
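The function above samples the permutation set uniformly at random. The original paper instead builds it greedily so that the selected permutations are far apart in Hamming distance (our reading of its Algorithm 1: each new permutation maximizes the summed distance to those already chosen), which makes the 1000 classes easier to tell apart. A minimal NumPy sketch of that greedy selection, under that assumption; `generate_max_hamming_permutations` is a hypothetical name, not part of this notebook:

```python
import itertools
import numpy as np

def generate_max_hamming_permutations(n_permutations, n_tiles=9, seed=42):
    """Greedily pick permutations that are far apart in Hamming distance (sketch)."""
    rng = np.random.default_rng(seed)
    # All n_tiles! candidates (362,880 rows for 9 tiles, about 3.3 MB as int8)
    all_perms = np.array(list(itertools.permutations(range(n_tiles))), dtype=np.int8)
    if n_permutations > len(all_perms):
        raise ValueError('n_permutations exceeds the number of distinct permutations')

    total_dist = np.zeros(len(all_perms), dtype=np.int64)  # summed distance to the selected set
    available = np.ones(len(all_perms), dtype=bool)        # candidates not selected yet
    j = int(rng.integers(len(all_perms)))                  # random first permutation
    selected = []
    for _ in range(n_permutations):
        selected.append(j)
        available[j] = False
        # Hamming distance = number of positions where two permutations differ
        total_dist += (all_perms != all_perms[j]).sum(axis=1)
        # Next pick: the remaining candidate with the largest summed distance
        j = int(np.argmax(np.where(available, total_dist, -1)))
    return [tuple(int(v) for v in all_perms[i]) for i in selected]
```

It returns the same format as `generate_permutations`, so `permutations = generate_max_hamming_permutations(n_permutations)` would be a drop-in replacement; with 9 tiles each greedy step scans all 9! candidates, so expect it to take a few seconds for 1000 permutations.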

In [9]:
# Generate n_permutations permutations for 9 tiles, these values are related to the complexity of the task and should be updated as necessary
n_permutations = 1000 ##essentially the number of classes
n_tiles = 9
permutations = generate_permutations(n_permutations, n_tiles)

# Enumerate and store permutations
permutations_dict = {i: perm for i, perm in enumerate(permutations)} 

In [10]:
permutations[:3] 
print('These indices will basically be the possible "gold" labels :p')

These indices will basically be the possible "gold" labels :p
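To make the label concrete: the dataset places tile `perm[i]` at position `i`, and the model only has to predict the index of `perm` in the list. Knowing that permutation is equivalent to knowing where every tile belongs, since its inverse (`argsort`) restores the original order. A toy NumPy illustration, where the tiles are stand-in integers and `perm` is a hypothetical member of the label set:

```python
import numpy as np

# Toy stand-ins for the 9 tiles of a 3x3 grid: tile k is just the integer k
tiles = np.arange(9)
perm = (2, 0, 1, 5, 3, 4, 8, 6, 7)  # hypothetical permutation from the label set

# Same indexing as in the dataset: position i receives original tile perm[i]
shuffled = tiles[list(perm)]
# The inverse permutation (argsort) puts every tile back in its original position
restored = shuffled[np.argsort(perm)]
print(shuffled.tolist())  # [2, 0, 1, 5, 3, 4, 8, 6, 7]
print(restored.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 8]
```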


In [11]:
class JigsawPuzzleDataset(data.Dataset):
    def __init__(self, hf_dataset, permutations, transform=None):
        """
        Parameters:
            hf_dataset: HuggingFace Dataset object with an 'image' column.
            permutations: List of permutations (tuples of tile indices).
            transform: Optional transform applied to each tile; defaults to ToTensor()
                so that the tiles can always be stacked into a tensor.
        """
        self.dataset = hf_dataset
        self.permutations = permutations
        self.n_permutations = len(permutations)
        self.n_tiles = 9  # 3x3 grid -> this can be modified later
        self.transform = transform if transform is not None else transforms.ToTensor()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Load image from the HuggingFace dataset and convert to RGB
        image = self.dataset[idx]['image'].convert('RGB')  # Ensure image is in RGB

        # Resize to 255x255 so the image splits evenly into a 3x3 grid of 85x85 cells;
        # a 64x64 tile is then cropped from each cell (note: this ignores the aspect ratio)
        image = image.resize((255, 255))

        # Original paper: "We randomly crop a 225 × 225 pixel window from an image (red dashed box),
        # divide it into a 3 × 3 grid, and randomly pick a 64 × 64 pixel tiles from each 75 × 75 pixel cell."
        # Here the whole resized image is used, so the cells are 85x85 instead of 75x75.
        tiles = []
        tile_size = 85  # cell size: 85 * 3 = 255

        # Iterate over the grid cells (row-major order) and crop one tile from each
        for i in range(3):
            for j in range(3):
                # Get the cell boundaries and crop the cell
                left = j * tile_size
                upper = i * tile_size
                right = left + tile_size
                lower = upper + tile_size
                tile = image.crop((left, upper, right, lower))
                
                # Random crop of 64x64 pixels with random shifts
                shift_max = tile_size - 64  # Max shift to introduce randomness
                left_shift = random.randint(0, shift_max)
                upper_shift = random.randint(0, shift_max)
                tile = tile.crop((left_shift, upper_shift, left_shift + 64, upper_shift + 64))
                
                # Apply any transform passed as argument
                if self.transform is not None:
                    tile = self.transform(tile)
                    
                tiles.append(tile)
                
        # Select a random permutation from the pre-computed permutations
        perm_idx = random.randint(0, self.n_permutations - 1)
        perm = self.permutations[perm_idx]

        # Shuffle the tiles according to the permutation
        shuffled_tiles = [tiles[p] for p in perm]

        # Stack tiles into a tensor
        tiles_tensor = torch.stack(shuffled_tiles, dim=0)  # Shape: [9, 3, 64, 64]

        # Return the shuffled tiles and the permutation index which is the gold label we aim the model to predict
        return tiles_tensor, perm_idx
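The tile geometry in `__getitem__` can be checked in isolation. The sketch below (the helper name `jigsaw_crop_boxes` is hypothetical) returns the crop box of every tile in full-image coordinates. Because each 64 px tile lands at a random offset inside its cell, neighbouring tiles are separated by a random gap; the motivation given in the original paper is to keep the network from solving the puzzle through low-level shortcuts such as continuing edges across tile borders.

```python
import random

def jigsaw_crop_boxes(image_size=255, grid=3, crop=64, rng=random):
    """Crop boxes (left, upper, right, lower) of the jittered tiles, in row-major order."""
    cell = image_size // grid   # 85 px for the notebook (255 / 3), 75 px in the paper (225 / 3)
    shift_max = cell - crop     # maximum random offset inside a cell: 21 px here, 11 px in the paper
    boxes = []
    for i in range(grid):
        for j in range(grid):
            left = j * cell + rng.randint(0, shift_max)
            upper = i * cell + rng.randint(0, shift_max)
            boxes.append((left, upper, left + crop, upper + crop))
    return boxes
```

With the paper's settings, `jigsaw_crop_boxes(image_size=225)` gives 75 px cells and at most 11 px of jitter per axis.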


In [12]:
transform = transforms.Compose([transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),transforms.ToTensor()])
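The `transform` above only applies color jitter. The original paper additionally normalizes the mean and standard deviation of each tile independently, again to discourage shortcuts based on low-level image statistics. A minimal sketch, assuming the statistics are computed over all channels and pixels of a tile (per-channel statistics would be an alternative); `normalize_tile` is a hypothetical helper:

```python
import numpy as np

def normalize_tile(tile, eps=1e-5):
    """Standardize a single tile to zero mean and (approximately) unit standard deviation.

    Works on a NumPy array or a torch tensor of shape (C, H, W); the statistics come
    from this tile only, so every tile is normalized independently.
    """
    return (tile - tile.mean()) / (tile.std() + eps)
```

In the notebook it could be appended to the pipeline as `transforms.Lambda(normalize_tile)` after `transforms.ToTensor()`. Since `JigsawPuzzleDataset` applies `transform` to each tile separately, each tile then gets its own statistics.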

## Instantiate the model

In [28]:
class JigsawNet(nn.Module):
    def __init__(self,
                 n_permutations,
                 architecture='resnet',  # 'resnet' or 'vit'
                 ):

        super(JigsawNet, self).__init__()

        if architecture == 'resnet':
            # Backbone ResNet model TODO: replace by ResNet 50
            self.resnet = models.resnet18(weights=None)

            self.resnet.fc = nn.Identity()  # Remove the classification layer -> 512-d features per tile

        elif architecture == 'vit':
            raise NotImplementedError('ViT backbone is still a TODO')
        else:
            raise ValueError(f'Unknown architecture: {architecture}')

        # Fully connected layers << to dispose of after the pretext task
        self.fc = nn.Sequential(
            nn.Linear(512 * 9, 4096),  # each of the 9 tiles generates a 512-dimensional vector
            nn.ReLU(),
            nn.Linear(4096, n_permutations)
        )

    def forward(self, x):
        # x shape: [batch_size, 9, 3, 64, 64]
        batch_size, n_tiles = x.shape[:2]

        # Combine batch and tile dimensions (siamese network: the same weights process every tile)
        x = x.view(batch_size * n_tiles, *x.shape[2:])
        features = self.resnet(x)  # Shape: [batch_size * 9, 512]

        # Concatenate the tile features before the linear layers that learn to predict the permutation
        features = features.view(batch_size, -1)  # Shape: [batch_size, 9 * 512]

        # Predict the permutation index
        out = self.fc(features)  # Shape: [batch_size, n_permutations]
        return out
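The reshapes in `forward` are what make the network siamese: the 9 tiles are folded into the batch axis, processed by one shared backbone, and unfolded again so that the features of tile `k` occupy positions `k*D:(k+1)*D` of each row. A NumPy sketch of the same bookkeeping, where a single shared linear map is a hypothetical stand-in for the ResNet and the sizes are shrunk for illustration:

```python
import numpy as np

batch_size, n_tiles, feat_dim = 2, 9, 4   # feat_dim is 512 in JigsawNet
x = np.random.rand(batch_size, n_tiles, 3, 8, 8)

# 1) Fold the tiles into the batch axis so one shared backbone sees every tile
flat = x.reshape(batch_size * n_tiles, 3, 8, 8)
# 2) Stand-in "backbone": the same linear map applied to every tile
w = np.random.rand(3 * 8 * 8, feat_dim)
feats = flat.reshape(batch_size * n_tiles, -1) @ w      # [B*9, D]
# 3) Unfold and concatenate the per-tile features in tile order
concat = feats.reshape(batch_size, n_tiles * feat_dim)  # [B, 9*D]
```

Because the reshapes are row-major, `concat[b, k*D:(k+1)*D]` equals the features of tile `k` of sample `b`, so the head sees the tiles in their (shuffled) input order.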

## Training
The hyperparameters below are a starting point and should be tuned.

In [29]:
model = JigsawNet(n_permutations=n_permutations)
model = model.to(device)

In [18]:
# Create the datasets and dataloaders
train_dataset = JigsawPuzzleDataset(dataset['train'], permutations, transform=transform)
# valid_dataset = JigsawPuzzleDataset(dataset['valid'], permutations, transform=transform)
valid_dataset = JigsawPuzzleDataset(dataset['validation'], permutations, transform=transform)

train_loader = data.DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=0)
valid_loader = data.DataLoader(valid_dataset, batch_size=256, shuffle=False, num_workers=0)

In [19]:
print('Shape of dataset output: {}'.format(next(iter(train_loader))[0].shape))

Shape of dataset output: torch.Size([256, 9, 3, 64, 64])


In [20]:
criterion = nn.CrossEntropyLoss()

In [22]:
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
num_epochs = 30
log_each = 100
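A quick estimate of what `num_epochs = 30` costs on ImageNet-1k with the current setup, using the train split size from the `DatasetDict` above and the rate tqdm reports in the training cell (roughly 3.6 to 6 s/it; 3.65 s/it is assumed here). The loaders use `num_workers=0`, so all PIL decoding and cropping runs in the main process, which may be part of why each step is this slow.

```python
import math

train_images = 1_281_167   # ImageNet-1k train split size
batch_size = 256
sec_per_it = 3.65          # assumed, from the tqdm progress bars of the training cell

batches_per_epoch = math.ceil(train_images / batch_size)
hours_per_epoch = batches_per_epoch * sec_per_it / 3600
hours_total = hours_per_epoch * 30
print(batches_per_epoch, f'{hours_per_epoch:.1f} h/epoch', f'{hours_total:.0f} h total')
# 5005 5.1 h/epoch 152 h total
```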

In [25]:
model_name = 'jigsaw_rn18_imnt1k'  # the backbone in JigsawNet is ResNet18

In [30]:
best_accuracy = 0.0  # best validation accuracy so far; a checkpoint is saved whenever it improves
for epoch in range(num_epochs):
    model.train()
    avg_loss = 0.0
    for batch_idx, (tiles, perm_idx) in enumerate(tqdm(train_loader)):
        tiles = tiles.to(device)  # Shape: [batch_size, 9, 3, 64, 64]
        perm_idx = perm_idx.to(device)  # Shape: [batch_size]

        optimizer.zero_grad()
        outputs = model(tiles)  # Shape: [batch_size, n_permutations]
        loss = criterion(outputs, perm_idx)
        loss.backward()
        optimizer.step()

        avg_loss += loss.item()
        if batch_idx % log_each == log_each - 1:
            # Average training loss over the last `log_each` batches
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}], Loss: {avg_loss / log_each:.4f}')
            avg_loss = 0.0

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for tiles, perm_idx in tqdm(valid_loader):
            tiles = tiles.to(device)
            perm_idx = perm_idx.to(device)

            outputs = model(tiles)
            loss = criterion(outputs, perm_idx)
            val_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += perm_idx.size(0)
            correct += (predicted == perm_idx).sum().item()

    val_accuracy = 100 * correct / total
    avg_val_loss = val_loss / len(valid_loader)
    print(f'Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        print('Saving checkpoint')
        torch.save(model.state_dict(), f'{model_name}.pth')


  2%|▏         | 100/5005 [06:12<4:58:45,  3.65s/it]

Epoch [1/30], Batch [100], Loss: 6.9366


  4%|▍         | 200/5005 [12:21<4:54:15,  3.67s/it]

Epoch [1/30], Batch [200], Loss: 6.9307


  6%|▌         | 300/5005 [18:42<4:49:00,  3.69s/it]

Epoch [1/30], Batch [300], Loss: 6.9345


  8%|▊         | 400/5005 [24:52<4:39:24,  3.64s/it]

Epoch [1/30], Batch [400], Loss: 6.9312


 10%|▉         | 500/5005 [30:58<4:42:15,  3.76s/it]

Epoch [1/30], Batch [500], Loss: 6.9347


 12%|█▏        | 600/5005 [37:01<4:28:26,  3.66s/it]

Epoch [1/30], Batch [600], Loss: 6.9344


 14%|█▍        | 700/5005 [43:08<4:25:41,  3.70s/it]

Epoch [1/30], Batch [700], Loss: 6.9349


 16%|█▌        | 800/5005 [49:14<4:13:45,  3.62s/it]

Epoch [1/30], Batch [800], Loss: 6.9311


 18%|█▊        | 900/5005 [57:41<6:47:06,  5.95s/it]

Epoch [1/30], Batch [900], Loss: 6.9330


 20%|█▉        | 1000/5005 [1:06:29<6:07:00,  5.50s/it]

Epoch [1/30], Batch [1000], Loss: 6.9339


 22%|██▏       | 1100/5005 [1:14:43<5:47:08,  5.33s/it]

Epoch [1/30], Batch [1100], Loss: 6.9354


 24%|██▍       | 1200/5005 [1:22:22<4:54:56,  4.65s/it]

Epoch [1/30], Batch [1200], Loss: 6.9348


 26%|██▌       | 1300/5005 [1:30:49<5:01:24,  4.88s/it]

Epoch [1/30], Batch [1300], Loss: 6.9332


 28%|██▊       | 1400/5005 [1:38:54<4:40:56,  4.68s/it]

Epoch [1/30], Batch [1400], Loss: 6.9354


 30%|██▉       | 1500/5005 [1:46:46<4:55:10,  5.05s/it]

Epoch [1/30], Batch [1500], Loss: 6.9328


 32%|███▏      | 1600/5005 [1:55:04<4:46:43,  5.05s/it]

Epoch [1/30], Batch [1600], Loss: 6.9340


 34%|███▍      | 1700/5005 [2:03:50<5:18:17,  5.78s/it]

Epoch [1/30], Batch [1700], Loss: 6.9331


 36%|███▌      | 1800/5005 [2:10:44<3:18:01,  3.71s/it]

Epoch [1/30], Batch [1800], Loss: 6.9329


 38%|███▊      | 1900/5005 [2:17:13<3:31:04,  4.08s/it]

Epoch [1/30], Batch [1900], Loss: 6.9318


 40%|███▉      | 2000/5005 [2:23:48<3:20:34,  4.00s/it]

Epoch [1/30], Batch [2000], Loss: 6.9311


 42%|████▏     | 2100/5005 [2:30:08<3:03:05,  3.78s/it]

Epoch [1/30], Batch [2100], Loss: 6.9331


 44%|████▍     | 2200/5005 [2:36:24<2:56:21,  3.77s/it]

Epoch [1/30], Batch [2200], Loss: 6.9321


 46%|████▌     | 2300/5005 [2:42:48<2:44:25,  3.65s/it]

Epoch [1/30], Batch [2300], Loss: 6.9325


 48%|████▊     | 2400/5005 [2:49:19<3:04:24,  4.25s/it]

Epoch [1/30], Batch [2400], Loss: 6.9306


 50%|████▉     | 2500/5005 [2:55:57<2:35:04,  3.71s/it]

Epoch [1/30], Batch [2500], Loss: 6.9317


 53%|█████▎    | 2658/5005 [3:06:28<2:44:39,  4.21s/it]

KeyboardInterrupt
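For reference, a model that assigns the same probability 1/K to each of K classes has cross-entropy -log(1/K) = log(K). With 1000 permutations that is about 6.908, so the logged losses (6.93 to 6.94 over the first ~2,500 batches) suggest the network was still at chance level on ImageNet-1k when this run was interrupted:

```python
import math

n_permutations = 1000  # number of classes in the pretext task

# Uniform predictions (probability 1/K for every class) give cross-entropy -log(1/K) = log(K)
chance_loss = math.log(n_permutations)
print(f'Chance-level loss: {chance_loss:.4f}')  # Chance-level loss: 6.9078
```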



In [None]:
tiles.shape

In [24]:
## save only resnet
torch.save(model.resnet.state_dict(), f'{model_name}_resnet.pth')

NameError: name 'model_name' is not defined

In [18]:
model.eval()
val_loss = 0.0
correct = 0
total = 0
## just evaluate
with torch.no_grad():
    for tiles, perm_idx in tqdm(valid_loader):
        tiles = tiles.to(device)
        perm_idx = perm_idx.to(device)

        outputs = model(tiles)
        loss = criterion(outputs, perm_idx)
        val_loss += loss.item()

        _, predicted = torch.max(outputs.data, 1)
        total += perm_idx.size(0)
        correct += (predicted == perm_idx).sum().item()

val_accuracy = 100 * correct / total
avg_val_loss = val_loss / len(valid_loader)
print(f'Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

100%|██████████| 40/40 [01:03<00:00,  1.58s/it]

Validation Loss: 0.8383, Validation Accuracy: 68.19%





### Training linear classification head

In [14]:
model = JigsawNet(n_permutations=n_permutations).to(device)
model_path = '/home/emilio.villa/nlp_local/cv_ptt/jigsaw_rn18_tinyimnt.pth'
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.eval()

JigsawNet(
  (resnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_ru

In [19]:
## save only resnet
model_rn_path = '/home/emilio.villa/nlp_local/cv_ptt/jigsaw_rn18_tinyimnt_resnet.pth'
torch.save(model.resnet.state_dict(), model_rn_path)
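The backbone weights can also be pulled out of a full JigsawNet checkpoint without re-instantiating the model: `state_dict()` prefixes every entry with the attribute name, so the ResNet weights are the keys starting with `resnet.`. A sketch with a hypothetical helper, shown on a plain dict so it runs without a checkpoint:

```python
def extract_backbone_state_dict(full_state_dict, prefix='resnet.'):
    """Keep only the keys under `prefix` and strip it, e.g. 'resnet.conv1.weight' -> 'conv1.weight'."""
    return {k[len(prefix):]: v for k, v in full_state_dict.items() if k.startswith(prefix)}

# Toy stand-in for JigsawNet().state_dict(): backbone keys plus the permutation head ('fc.*')
full = {'resnet.conv1.weight': 1, 'resnet.layer1.0.bn1.running_mean': 2, 'fc.0.weight': 3}
print(extract_backbone_state_dict(full))
# {'conv1.weight': 1, 'layer1.0.bn1.running_mean': 2}
```

With a real checkpoint this would be `extract_backbone_state_dict(torch.load(model_path, map_location='cpu', weights_only=True))`; the result has the same keys as `model.resnet.state_dict()`.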

In [42]:
class ClassificationModel(nn.Module):
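    def __init__(
        self,
        num_classes,
        architecture='resnet',  # 'resnet' or 'vit'
        ):
        super(ClassificationModel, self).__init__()
        # ResNet backbone; the Jigsaw-pretrained weights from the pretext task are loaded into it separately

        if architecture == 'resnet':
            # Backbone ResNet model TODO: replace by ResNet 50
            self.features = models.resnet18(weights=None)
            # Removing the ImageNet classification layer is necessary: otherwise the backbone outputs
            # 1000 logits instead of the 512-d features expected by linear_proj, and the Jigsaw
            # checkpoint (saved with fc = Identity) has no fc weights to load
            self.features.fc = nn.Identity()

        elif architecture == 'vit':
            raise NotImplementedError('ViT backbone is still a TODO')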
            
        # Classification layer
        self.linear_proj = nn.Linear(512, num_classes)

    def forward(self, x):
        # x shape: [batch_size, 3, H, W]
        features = self.features(x)  # Shape: [batch_size, 512]
        out = self.linear_proj(features)  # Shape: [batch_size, num_classes]
        return out


# Instantiate the transfer learning model
num_classes = 200  # Tiny ImageNet has 200 classes
# resnet_model = model.resnet
classifier = ClassificationModel(num_classes=num_classes, architecture = 'resnet')
classifier = classifier.to(device)

# for param in classifier.features.parameters():
#     param.requires_grad = False


In [34]:
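## Load the Jigsaw-pretrained ResNet weights into the classifier backbone
## (leave these lines commented out to train on a randomly initialized backbone, i.e. the baseline)
# pretrained_path = '/home/emilio.villa/nlp_local/cv_ptt/jigsaw_rn18_tinyimnt_resnet.pth'
# classifier.features.load_state_dict(torch.load(pretrained_path, map_location=device, weights_only=True))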

<All keys matched successfully>

In [39]:
# Freeze the backbone parameters for linear probing (skip this cell for full fine-tuning)
for param in classifier.features.parameters():
    param.requires_grad = False
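For the full fine-tuning TODO, the backbone has to be trainable again and passed to the optimizer; the training cell below only gives `classifier.linear_proj.parameters()` to SGD. One common option is separate parameter groups, with a smaller learning rate for the pretrained backbone than for the fresh head. A sketch assuming a model with `features` and `linear_proj` attributes like `ClassificationModel`; the helper name and the learning rates are hypothetical placeholders:

```python
import torch.nn as nn
import torch.optim as optim

def make_finetune_optimizer(model, backbone_lr=1e-3, head_lr=1e-2, momentum=0.9, weight_decay=1e-4):
    """SGD over backbone and head, each with its own learning rate (full fine-tuning sketch)."""
    # Undo the linear-probing freeze so the backbone receives gradients again
    for p in model.features.parameters():
        p.requires_grad = True
    param_groups = [
        {'params': model.features.parameters(), 'lr': backbone_lr},
        {'params': model.linear_proj.parameters(), 'lr': head_lr},
    ]
    # Momentum and weight decay are shared by both groups
    return optim.SGD(param_groups, lr=head_lr, momentum=momentum, weight_decay=weight_decay)
```

Usage would be `optimizer = make_finetune_optimizer(classifier)` in place of the linear-probing optimizer.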

In [40]:
# Define the transforms
classification_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])

# Define custom dataset class
class ClassificationDataset(data.Dataset):
    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]  # fetch (and decode) the example only once
        image = sample['image'].convert('RGB')
        label = sample['label']
        if self.transform:
            image = self.transform(image)
        return image, label

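# Instantiate datasets and data loaders (Tiny ImageNet: 200 classes, 'train' and 'valid' splits)
tinyImageNet_dataset = load_dataset("zh-plus/tiny-imagenet", cache_dir=download_path)
train_classification_dataset = ClassificationDataset(tinyImageNet_dataset['train'], transform=classification_transform)
val_classification_dataset = ClassificationDataset(tinyImageNet_dataset['valid'], transform=classification_transform)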

train_classification_loader = data.DataLoader(
    train_classification_dataset, batch_size=64, shuffle=True, num_workers=0
)
val_classification_loader = data.DataLoader(
    val_classification_dataset, batch_size=64, shuffle=False, num_workers=0
)


In [41]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(classifier.linear_proj.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

# Training loop for transfer learning
num_epochs = 15  # Adjust the number of epochs as needed

for epoch in range(num_epochs):
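    classifier.train()
    # Linear probing: keep the frozen backbone in eval mode so its BatchNorm running statistics stay fixed
    if not any(p.requires_grad for p in classifier.features.parameters()):
        classifier.features.eval()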
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in tqdm(train_classification_loader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = classifier(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    train_accuracy = 100 * correct / total
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_classification_loader):.4f}, '
          f'Train Accuracy: {train_accuracy:.2f}%')

    # Validation loop
    classifier.eval()
    correct = 0
    total = 0
    val_loss = 0.0
    with torch.no_grad():
        for images, labels in val_classification_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = classifier(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_accuracy = 100 * correct / total
    avg_val_loss = val_loss / len(val_classification_loader)
    print(f'Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

100%|██████████| 1563/1563 [02:27<00:00, 10.63it/s]


Epoch [1/15], Loss: 5.1392, Train Accuracy: 2.68%
Validation Loss: 4.9973, Validation Accuracy: 3.80%


100%|██████████| 1563/1563 [02:37<00:00,  9.95it/s]


Epoch [2/15], Loss: 4.9252, Train Accuracy: 4.47%
Validation Loss: 4.8776, Validation Accuracy: 4.80%


100%|██████████| 1563/1563 [02:45<00:00,  9.47it/s]


Epoch [3/15], Loss: 4.8237, Train Accuracy: 5.38%
Validation Loss: 4.8147, Validation Accuracy: 5.20%


100%|██████████| 1563/1563 [02:44<00:00,  9.51it/s]


Epoch [4/15], Loss: 4.7573, Train Accuracy: 6.05%
Validation Loss: 4.7609, Validation Accuracy: 6.31%


100%|██████████| 1563/1563 [02:30<00:00, 10.38it/s]


Epoch [5/15], Loss: 4.7052, Train Accuracy: 6.87%
Validation Loss: 4.7215, Validation Accuracy: 6.67%


100%|██████████| 1563/1563 [02:49<00:00,  9.22it/s]


Epoch [6/15], Loss: 4.6665, Train Accuracy: 7.15%
Validation Loss: 4.6870, Validation Accuracy: 6.96%


100%|██████████| 1563/1563 [02:34<00:00, 10.08it/s]


Epoch [7/15], Loss: 4.6338, Train Accuracy: 7.60%
Validation Loss: 4.6672, Validation Accuracy: 7.28%


100%|██████████| 1563/1563 [02:55<00:00,  8.89it/s]


Epoch [8/15], Loss: 4.6013, Train Accuracy: 7.87%
Validation Loss: 4.6382, Validation Accuracy: 7.85%


100%|██████████| 1563/1563 [02:26<00:00, 10.63it/s]


Epoch [9/15], Loss: 4.5801, Train Accuracy: 8.12%
Validation Loss: 4.6449, Validation Accuracy: 7.59%


100%|██████████| 1563/1563 [02:49<00:00,  9.22it/s]


Epoch [10/15], Loss: 4.5573, Train Accuracy: 8.68%
Validation Loss: 4.6266, Validation Accuracy: 7.88%


100%|██████████| 1563/1563 [02:27<00:00, 10.61it/s]


Epoch [11/15], Loss: 4.5364, Train Accuracy: 8.79%
Validation Loss: 4.6232, Validation Accuracy: 7.71%


100%|██████████| 1563/1563 [02:33<00:00, 10.16it/s]


Epoch [12/15], Loss: 4.5180, Train Accuracy: 9.10%
Validation Loss: 4.6077, Validation Accuracy: 7.74%


100%|██████████| 1563/1563 [02:26<00:00, 10.68it/s]


Epoch [13/15], Loss: 4.5026, Train Accuracy: 9.20%
Validation Loss: 4.5907, Validation Accuracy: 8.20%


 78%|███████▊  | 1217/1563 [01:54<00:32, 10.62it/s]


KeyboardInterrupt: 

In [20]:
# Now running with only the linear head unfrozen and

1

In [None]:
'''
15 epochs with (supposedly) Jigsaw-pretrained weights and a linear classification head

100%|██████████| 1563/1563 [02:37<00:00,  9.89it/s]
Epoch [1/15], Loss: 5.3484, Train Accuracy: 1.14%
Validation Loss: 5.2455, Validation Accuracy: 2.05%
100%|██████████| 1563/1563 [02:28<00:00, 10.50it/s]
Epoch [2/15], Loss: 5.2022, Train Accuracy: 2.18%
Validation Loss: 5.1732, Validation Accuracy: 2.21%
100%|██████████| 1563/1563 [02:24<00:00, 10.79it/s]
Epoch [3/15], Loss: 5.1099, Train Accuracy: 2.87%
Validation Loss: 5.0519, Validation Accuracy: 2.84%
100%|██████████| 1563/1563 [02:26<00:00, 10.68it/s]
Epoch [4/15], Loss: 5.0460, Train Accuracy: 3.43%
Validation Loss: 5.0245, Validation Accuracy: 3.62%
100%|██████████| 1563/1563 [02:24<00:00, 10.85it/s]
Epoch [5/15], Loss: 4.9947, Train Accuracy: 3.78%
Validation Loss: 4.9768, Validation Accuracy: 4.19%
100%|██████████| 1563/1563 [02:25<00:00, 10.75it/s]
Epoch [6/15], Loss: 4.9522, Train Accuracy: 4.26%
Validation Loss: 4.9880, Validation Accuracy: 3.43%
100%|██████████| 1563/1563 [03:02<00:00,  8.58it/s]
Epoch [7/15], Loss: 4.9156, Train Accuracy: 4.53%
Validation Loss: 4.8988, Validation Accuracy: 4.77%
100%|██████████| 1563/1563 [02:24<00:00, 10.82it/s]
Epoch [8/15], Loss: 4.8836, Train Accuracy: 4.96%
Validation Loss: 4.8678, Validation Accuracy: 5.11%
100%|██████████| 1563/1563 [02:34<00:00, 10.12it/s]
Epoch [9/15], Loss: 4.8599, Train Accuracy: 5.09%
Validation Loss: 4.9123, Validation Accuracy: 4.35%
100%|██████████| 1563/1563 [02:26<00:00, 10.68it/s]
Epoch [10/15], Loss: 4.8314, Train Accuracy: 5.39%
Validation Loss: 4.8441, Validation Accuracy: 5.19%
100%|██████████| 1563/1563 [02:26<00:00, 10.69it/s]
Epoch [11/15], Loss: 4.8131, Train Accuracy: 5.66%
Validation Loss: 4.8576, Validation Accuracy: 5.14%
100%|██████████| 1563/1563 [02:28<00:00, 10.50it/s]
Epoch [12/15], Loss: 4.7973, Train Accuracy: 5.63%
Validation Loss: 4.8183, Validation Accuracy: 5.17%
100%|██████████| 1563/1563 [02:25<00:00, 10.75it/s]
Epoch [13/15], Loss: 4.7789, Train Accuracy: 6.00%
Validation Loss: 4.8186, Validation Accuracy: 5.46%
100%|██████████| 1563/1563 [02:27<00:00, 10.58it/s]
Epoch [14/15], Loss: 4.7599, Train Accuracy: 6.20%
Validation Loss: 4.7920, Validation Accuracy: 5.58%
100%|██████████| 1563/1563 [02:24<00:00, 10.81it/s]
Epoch [15/15], Loss: 4.7487, Train Accuracy: 6.28%
Validation Loss: 4.7666, Validation Accuracy: 5.74%
'''

'''
15 epochs with (supposedly) randomly initialized weights: better performance :( Why?
The log is identical to the output of the training cell above (epochs 1-13; the run was
interrupted during epoch 14, best validation accuracy 8.20% at epoch 13).
'''
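To put both runs in context, the chance levels for Tiny ImageNet's 200 classes are worked out below. Both runs start near the chance-level loss (epoch-1 training losses of 5.35 for the pretrained backbone and 5.14 for the random one) and end well above chance accuracy (5.74% after 15 epochs vs 8.20% after 13 epochs), but still far from a useful classifier:

```python
import math

num_classes = 200  # Tiny ImageNet

chance_accuracy = 100 / num_classes   # % accuracy of uniform random guessing
chance_loss = math.log(num_classes)   # cross-entropy of uniform predictions
print(f'Chance accuracy: {chance_accuracy:.2f}%, chance loss: {chance_loss:.4f}')
# Chance accuracy: 0.50%, chance loss: 5.2983
```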