In [1]:
import os
import time
import random
import glob
from PIL import Image
from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import warnings
from tqdm import tqdm

warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
from torchvision import datasets, transforms, models
import torchvision.transforms.functional as TF
import torch.nn.functional as F

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from sklearn import linear_model, model_selection

In [2]:
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
np.random.seed(42)
random.seed(42)

In [3]:
# Read the input data
train_metadata_path = "custom_korean_family_dataset_resolution_128\custom_train_dataset.csv"
train_data_path = "custom_korean_family_dataset_resolution_128\\train_images"

test_metadata_path = "custom_korean_family_dataset_resolution_128\custom_val_dataset.csv"
test_data_path = "custom_korean_family_dataset_resolution_128\\val_images"

unseen_metadata_path = "custom_korean_family_dataset_resolution_128\custom_test_dataset.csv"
unseen_data_path = "custom_korean_family_dataset_resolution_128\\test_images"

train_metadata = pd.read_csv(train_metadata_path)
test_metadata = pd.read_csv(test_metadata_path)
unseen_metadata = pd.read_csv(unseen_metadata_path)

In [4]:
class Dataset(Dataset):
    def __init__(self, metadata, img_dir, transform = None, forget = False, retain = False):
        self.metadata = metadata
        self.img_dir = img_dir
        self.transform = transform

        # Processing the metadata 
        image_age_list = []
        for _, row in metadata.iterrows():
            img_path = row['image_path']
            age = row['age_class']
            image_age_list.append([img_path, age])
        
        self.image_age_list = image_age_list
        self.age_to_label = {"a": 0, "b": 1, "c": 2, "d": 3, "e": 4, "f": 5, "g": 6, "h": 7}

        if forget:
            self.image_age_list = self.image_age_list[:1500]
        if retain:
            self.image_age_list = self.image_age_list[1500:]
        
    def __len__(self):
        return len(self.image_age_list)

    def __getitem__(self, index):
        img_path, age = self.image_age_list[index]
        img = Image.open(os.path.join(self.img_dir, img_path))
        label = self.age_to_label[age]

        if self.transform:
            img = self.transform(img)
        
        return img, label

In [5]:
train_transform = transforms.Compose([
    transforms.Resize(128),
    # transforms.RandomHorizontalFlip(),
    # transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
    transforms.ToTensor()
])

transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor()
])

train_data = Dataset(train_metadata, train_data_path, train_transform)
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)

test_data = Dataset(test_metadata, test_data_path, transform)
test_dataloader = DataLoader(test_data, batch_size=64)

unseen_data = Dataset(unseen_metadata, unseen_data_path, transform)
unseen_dataloader = DataLoader(unseen_data, batch_size=64)

In [6]:
# Train, test and unseen data split
print(f"Number of Training images: {train_data.__len__()}")
print(f"Number of Testing images: {test_data.__len__()}")
print(f"Number of Unseen images: {unseen_data.__len__()}")

Number of Training images: 10025
Number of Testing images: 1539
Number of Unseen images: 1504


In [7]:
label_to_age = {
    0: "0-6 years old",
    1: "7-12 years old",
    2: "13-19 years old",
    3: "20-30 years old",
    4: "31-45 years old",
    5: "46-55 years old",
    6: "56-66 years old",
    7: "67-80 years old"
}

In [8]:
from torch.utils.data import Dataset, Subset

In [10]:
class ShardSliceTrainer:
    def __init__(self, shards, slices, train_loaders_lists, test_loader, label_to_age):
        self.shards = shards
        self.slices = slices
        self.train_loaders_lists = train_loaders_lists
        self.test_loader = test_loader
        self.label_to_age = label_to_age

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.models = []
        self.optimizers = []
        self.criterion = nn.CrossEntropyLoss().to(self.device)  # Move criterion to GPU
        for _ in range(shards):
            model = models.resnet18(pretrained=True)
            # Freeze all layers except the final few layers
            for param in model.parameters():
                param.requires_grad = False
            for param in model.layer4.parameters():
                param.requires_grad = True
            model.fc = nn.Linear(model.fc.in_features, len(label_to_age))
            model = model.to(self.device)
            optimizer = optim.Adam(model.parameters(), lr=0.001)
            self.models.append(model)
            self.optimizers.append(optimizer)

    def train(self):
        for shard_idx, model in enumerate(self.models):
            print(f"Training Shard {shard_idx+1}/{self.shards}")
            train_loaders = self.train_loaders_lists[shard_idx]

            for slice_idx, loaders in enumerate(train_loaders):
                print(f"Shard: {shard_idx + 1} | Training Slice {slice_idx+1}/{self.slices}")
                running_loss = 0.0
                correct = 0
                total = 0

                # Concatenate datasets within the current shard up to the current slice
                if slice_idx == 0:
                    cumulative_dataset = loaders.dataset
                else:
                    cumulative_dataset = torch.utils.data.ConcatDataset([cumulative_dataset, loaders.dataset])

                # Create a DataLoader from the concatenated dataset
                cumulative_loader = DataLoader(cumulative_dataset, batch_size=1, shuffle=True)

                # Train the model for dynamic number of epochs on cumulative data
                for epoch in tqdm(range(slice_idx + 1), desc=f"Epoch {slice_idx+1}/{self.slices}"):
                    for data in cumulative_loader:
                        inputs = data[0].to(self.device)
                        labels = data[1].to(self.device)
                        optimizer = self.optimizers[shard_idx]
                        optimizer.zero_grad()
                        outputs = model(inputs)
                        loss = self.criterion(outputs, labels)
                        loss.backward()
                        optimizer.step()
                        running_loss += loss.item()

                        _, predicted = torch.max(outputs, 1)
                        total += labels.size(0)
                        correct += (predicted == labels).sum().item()

                accuracy = 100 * correct / total
                print(f"Shard {shard_idx}, Slice {slice_idx+1}, Loss: {running_loss/(total):.2f}, Accuracy: {accuracy:.2f}%")
            print()
    
    def test(self):
        print("Testing...")
        correct = 0
        total = 0
        total_loss = 0.0

        # Iterate through each test data point
        with torch.no_grad():
            for data in self.test_loader:
                inputs, labels = data[0].to(self.device), data[1].to(self.device)
                ensemble_outputs = torch.zeros(inputs.size(0), len(self.label_to_age)).to(self.device)  # Initialize ensemble outputs
                ensemble_loss = 0.0

                # Aggregate predictions and losses from all models in the ensemble
                for model in self.models:
                    outputs = model(inputs)
                    ensemble_outputs += outputs
                    loss = self.criterion(outputs, labels)
                    ensemble_loss += loss.item()

                # Choose prediction with maximum probability in ensemble outputs
                _, predicted = torch.max(ensemble_outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            # Accumulate loss from each model and calculate average loss
            total_loss += ensemble_loss / len(self.models)

        # Calculate accuracy and average loss over all batches
        accuracy = 100 * correct / total
        loss = total_loss / len(self.test_loader)  # Calculate average loss over all batches
        print(f"Ensemble Test Accuracy: {accuracy:.2f}%")
        print(f"Ensemble Test Loss: {loss:.2f}")
                    
                        
    def get_slice_index(self, image_index, shards, slices, dataset_length):
        samples_per_shard = dataset_length // shards
        samples_per_slice = samples_per_shard // slices

        # Iterate through each shard
        for shard_idx in range(shards):
            start_idx = shard_idx * samples_per_shard
            end_idx = (shard_idx + 1) * samples_per_shard

            # Check if the image index falls within the range of samples for the current shard
            if start_idx <= image_index < end_idx:
                # Calculate the relative index within the shard
                relative_index = image_index - start_idx

                # Calculate the slice index
                slice_index = relative_index // samples_per_slice
                return shard_idx, slice_index

        return None 

    def find_slice_idx(self, train_loader, image):
        slice_idx_list = 0
        for i in range(len(train_loader.dataset)):
            if torch.equal(image, train_loader.dataset[i][0]):
                slice_idx_list = i
                break
        return slice_idx_list
    
    def unlearn(self, dataloader, image_index):
        # Get shard and slice index for the image
        shard_idx, slice_idx = self.get_slice_index(image_index, self.shards, self.slices, len(dataloader.dataset))
        
        # Retrieve the model and optimizer for the identified shard
        model = self.models[shard_idx]
        optimizer = self.optimizers[shard_idx]
        
        image, label = dataloader.dataset[image_index]  # Extract image and label from dataloader

        # Gather data from the identified slice, excluding the image to be removed
        shard_slice_data_loader = self.train_loaders_lists[shard_idx][slice_idx]
        updated_data = []
        for data, target in shard_slice_data_loader:
            if not torch.equal(data, image):  # Exclude the image to be removed
                updated_data.append((data, target))
        
        # Create a new DataLoader or Subset for the retraining data
        retraining_data_loader = DataLoader(updated_data, batch_size=1, shuffle=True)
        
        # Retrain the model using the retraining data
        for epoch in range(slice_idx + 1):
            for batch in retraining_data_loader:
                inputs, labels = batch[0], batch[1]
                inputs = inputs[0].to(self.device)
                labels = labels[0].to(self.device)
                
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                optimizer.step()
        
        # Optionally return the updated model
        return model



In [16]:
shards = 4
slices = 10

shard_data_loaders_lists = []

for shard_idx in range(shards):
    start_idx = (len(train_dataloader.dataset) // shards) * shard_idx
    end_idx = (len(train_dataloader.dataset) // shards) * (shard_idx + 1)

    shard_subset = Subset(train_dataloader.dataset, range(start_idx, end_idx))

    samples_per_slice = len(shard_subset) // slices

    shard_slice_data_loaders = []

    for slice_idx in range(slices):
        slice_start_idx = slice_idx * samples_per_slice
        slice_end_idx = (slice_idx + 1) * samples_per_slice if slice_idx < slices - 1 else len(shard_subset)

        slice_subset = Subset(shard_subset, range(slice_start_idx, slice_end_idx))

        slice_data_loader = DataLoader(slice_subset, batch_size=1, shuffle=True)

        shard_slice_data_loaders.append(slice_data_loader)

    shard_data_loaders_lists.append(shard_slice_data_loaders)

In [17]:
shard_slice_trainer = ShardSliceTrainer(shards, slices, shard_data_loaders_lists, test_dataloader, label_to_age)
shard_slice_trainer.train()

Training Shard 1/4
Shard: 1 | Training Slice 1/10


Epoch 1/10: 100%|██████████| 1/1 [00:02<00:00,  2.40s/it]


Shard 0, Slice 1, Loss: 2.27, Accuracy: 17.60%
Shard: 1 | Training Slice 2/10


Epoch 2/10: 100%|██████████| 2/2 [00:10<00:00,  5.06s/it]


Shard 0, Slice 2, Loss: 2.02, Accuracy: 26.20%
Shard: 1 | Training Slice 3/10


Epoch 3/10: 100%|██████████| 3/3 [00:23<00:00,  7.96s/it]


Shard 0, Slice 3, Loss: 1.37, Accuracy: 48.53%
Shard: 1 | Training Slice 4/10


Epoch 4/10: 100%|██████████| 4/4 [00:40<00:00, 10.08s/it]


Shard 0, Slice 4, Loss: 0.63, Accuracy: 78.10%
Shard: 1 | Training Slice 5/10


Epoch 5/10: 100%|██████████| 5/5 [00:52<00:00, 10.57s/it]


Shard 0, Slice 5, Loss: 0.22, Accuracy: 93.94%
Shard: 1 | Training Slice 6/10


Epoch 6/10: 100%|██████████| 6/6 [01:20<00:00, 13.50s/it]


Shard 0, Slice 6, Loss: 0.22, Accuracy: 94.26%
Shard: 1 | Training Slice 7/10


Epoch 7/10: 100%|██████████| 7/7 [01:41<00:00, 14.46s/it]


Shard 0, Slice 7, Loss: 0.14, Accuracy: 96.40%
Shard: 1 | Training Slice 8/10


Epoch 8/10: 100%|██████████| 8/8 [02:26<00:00, 18.28s/it]


Shard 0, Slice 8, Loss: 0.12, Accuracy: 96.63%
Shard: 1 | Training Slice 9/10


Epoch 9/10: 100%|██████████| 9/9 [03:45<00:00, 25.06s/it]


Shard 0, Slice 9, Loss: 0.08, Accuracy: 98.04%
Shard: 1 | Training Slice 10/10


Epoch 10/10: 100%|██████████| 10/10 [04:56<00:00, 29.64s/it]


Shard 0, Slice 10, Loss: 0.08, Accuracy: 97.84%

Training Shard 2/4
Shard: 2 | Training Slice 1/10


Epoch 1/10: 100%|██████████| 1/1 [00:03<00:00,  3.36s/it]


Shard 1, Slice 1, Loss: 2.43, Accuracy: 14.40%
Shard: 2 | Training Slice 2/10


Epoch 2/10: 100%|██████████| 2/2 [00:11<00:00,  5.87s/it]


Shard 1, Slice 2, Loss: 1.91, Accuracy: 27.60%
Shard: 2 | Training Slice 3/10


Epoch 3/10: 100%|██████████| 3/3 [00:27<00:00,  9.14s/it]


Shard 1, Slice 3, Loss: 1.06, Accuracy: 61.91%
Shard: 2 | Training Slice 4/10


Epoch 4/10: 100%|██████████| 4/4 [00:39<00:00,  9.85s/it]


Shard 1, Slice 4, Loss: 0.34, Accuracy: 89.58%
Shard: 2 | Training Slice 5/10


Epoch 5/10: 100%|██████████| 5/5 [01:07<00:00, 13.43s/it]


Shard 1, Slice 5, Loss: 0.24, Accuracy: 93.57%
Shard: 2 | Training Slice 6/10


Epoch 6/10: 100%|██████████| 6/6 [01:45<00:00, 17.51s/it]


Shard 1, Slice 6, Loss: 0.21, Accuracy: 94.21%
Shard: 2 | Training Slice 7/10


Epoch 7/10: 100%|██████████| 7/7 [01:58<00:00, 16.87s/it]


Shard 1, Slice 7, Loss: 0.14, Accuracy: 95.97%
Shard: 2 | Training Slice 8/10


Epoch 8/10: 100%|██████████| 8/8 [02:18<00:00, 17.28s/it]


Shard 1, Slice 8, Loss: 0.09, Accuracy: 97.58%
Shard: 2 | Training Slice 9/10


Epoch 9/10: 100%|██████████| 9/9 [02:53<00:00, 19.30s/it]


Shard 1, Slice 9, Loss: 0.07, Accuracy: 98.13%
Shard: 2 | Training Slice 10/10


Epoch 10/10: 100%|██████████| 10/10 [03:28<00:00, 20.88s/it]


Shard 1, Slice 10, Loss: 0.06, Accuracy: 98.42%

Training Shard 3/4
Shard: 3 | Training Slice 1/10


Epoch 1/10: 100%|██████████| 1/1 [00:02<00:00,  2.10s/it]


Shard 2, Slice 1, Loss: 2.27, Accuracy: 15.60%
Shard: 3 | Training Slice 2/10


Epoch 2/10: 100%|██████████| 2/2 [00:08<00:00,  4.23s/it]


Shard 2, Slice 2, Loss: 1.88, Accuracy: 28.60%
Shard: 3 | Training Slice 3/10


Epoch 3/10: 100%|██████████| 3/3 [00:20<00:00,  6.68s/it]


Shard 2, Slice 3, Loss: 1.20, Accuracy: 54.49%
Shard: 3 | Training Slice 4/10


Epoch 4/10: 100%|██████████| 4/4 [00:33<00:00,  8.46s/it]


Shard 2, Slice 4, Loss: 0.56, Accuracy: 80.55%
Shard: 3 | Training Slice 5/10


Epoch 5/10: 100%|██████████| 5/5 [00:54<00:00, 10.95s/it]


Shard 2, Slice 5, Loss: 0.30, Accuracy: 91.60%
Shard: 3 | Training Slice 6/10


Epoch 6/10: 100%|██████████| 6/6 [01:12<00:00, 12.12s/it]


Shard 2, Slice 6, Loss: 0.15, Accuracy: 95.99%
Shard: 3 | Training Slice 7/10


Epoch 7/10: 100%|██████████| 7/7 [01:40<00:00, 14.29s/it]


Shard 2, Slice 7, Loss: 0.11, Accuracy: 97.40%
Shard: 3 | Training Slice 8/10


Epoch 8/10: 100%|██████████| 8/8 [02:16<00:00, 17.04s/it]


Shard 2, Slice 8, Loss: 0.08, Accuracy: 98.05%
Shard: 3 | Training Slice 9/10


Epoch 9/10: 100%|██████████| 9/9 [02:46<00:00, 18.45s/it]


Shard 2, Slice 9, Loss: 0.08, Accuracy: 97.94%
Shard: 3 | Training Slice 10/10


Epoch 10/10: 100%|██████████| 10/10 [03:26<00:00, 20.64s/it]


Shard 2, Slice 10, Loss: 0.04, Accuracy: 99.04%

Training Shard 4/4
Shard: 4 | Training Slice 1/10


Epoch 1/10: 100%|██████████| 1/1 [00:02<00:00,  2.45s/it]


Shard 3, Slice 1, Loss: 2.10, Accuracy: 19.60%
Shard: 4 | Training Slice 2/10


Epoch 2/10: 100%|██████████| 2/2 [00:08<00:00,  4.09s/it]


Shard 3, Slice 2, Loss: 1.65, Accuracy: 38.00%
Shard: 4 | Training Slice 3/10


Epoch 3/10: 100%|██████████| 3/3 [00:18<00:00,  6.06s/it]


Shard 3, Slice 3, Loss: 0.92, Accuracy: 65.82%
Shard: 4 | Training Slice 4/10


Epoch 4/10: 100%|██████████| 4/4 [00:34<00:00,  8.51s/it]


Shard 3, Slice 4, Loss: 0.45, Accuracy: 85.12%
Shard: 4 | Training Slice 5/10


Epoch 5/10: 100%|██████████| 5/5 [00:51<00:00, 10.21s/it]


Shard 3, Slice 5, Loss: 0.21, Accuracy: 94.35%
Shard: 4 | Training Slice 6/10


Epoch 6/10: 100%|██████████| 6/6 [01:12<00:00, 12.07s/it]


Shard 3, Slice 6, Loss: 0.14, Accuracy: 96.26%
Shard: 4 | Training Slice 7/10


Epoch 7/10: 100%|██████████| 7/7 [01:47<00:00, 15.35s/it]


Shard 3, Slice 7, Loss: 0.11, Accuracy: 97.18%
Shard: 4 | Training Slice 8/10


Epoch 8/10: 100%|██████████| 8/8 [02:09<00:00, 16.23s/it]


Shard 3, Slice 8, Loss: 0.07, Accuracy: 98.24%
Shard: 4 | Training Slice 9/10


Epoch 9/10: 100%|██████████| 9/9 [02:48<00:00, 18.68s/it]


Shard 3, Slice 9, Loss: 0.07, Accuracy: 98.17%
Shard: 4 | Training Slice 10/10


Epoch 10/10: 100%|██████████| 10/10 [03:46<00:00, 22.67s/it]

Shard 3, Slice 10, Loss: 0.06, Accuracy: 98.67%






In [19]:
shard_slice_trainer.test()

Testing...
Ensemble Test Accuracy: 32.29%
Ensemble Test Loss: 3.28


In [13]:
# Unlearning

import torch
from torch.utils.data import DataLoader

total_samples = len(train_dataloader.dataset)
num_images = 10

indices = list(range(total_samples))

torch.manual_seed(42)  
random_indices = torch.randperm(total_samples)

idx = random_indices[0]

image, label = train_dataloader.dataset[idx]
image

tensor([[[0.1529, 0.1529, 0.1333,  ..., 0.2471, 0.2392, 0.2510],
         [0.1255, 0.1373, 0.1490,  ..., 0.2078, 0.2784, 0.2353],
         [0.1059, 0.0902, 0.1059,  ..., 0.3020, 0.2549, 0.2392],
         ...,
         [0.8118, 0.8196, 0.8314,  ..., 0.8627, 0.7961, 0.7804],
         [0.7922, 0.8078, 0.8118,  ..., 0.9137, 0.8078, 0.8157],
         [0.7922, 0.8039, 0.8078,  ..., 0.8784, 0.8549, 0.7686]],

        [[0.1176, 0.1176, 0.1098,  ..., 0.2627, 0.2549, 0.2667],
         [0.1020, 0.1137, 0.1333,  ..., 0.2235, 0.2941, 0.2510],
         [0.1020, 0.0863, 0.1020,  ..., 0.3176, 0.2706, 0.2549],
         ...,
         [0.2588, 0.2667, 0.2784,  ..., 0.7647, 0.6980, 0.6824],
         [0.2706, 0.2784, 0.2784,  ..., 0.8078, 0.7020, 0.7098],
         [0.2863, 0.2902, 0.2745,  ..., 0.7725, 0.7451, 0.6588]],

        [[0.1216, 0.1216, 0.1176,  ..., 0.2588, 0.2510, 0.2627],
         [0.1020, 0.1137, 0.1294,  ..., 0.2196, 0.2980, 0.2471],
         [0.0863, 0.0706, 0.0863,  ..., 0.3216, 0.2824, 0.

In [14]:
shard_slice_trainer.find_slice_idx(train_dataloader, image)

5992

In [15]:
shard_slice_trainer.unlearn(train_dataloader, 5992)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
SHARD = 2; SLICE = 5

'''
Training Shard 1/2
Shard: 1 | Training Slice 1/5
Epoch 1/5: 100%|██████████| 1/1 [00:11<00:00, 11.38s/it]
Shard 0, Slice 1, Loss: 2.18, Accuracy: 20.66%
Shard: 1 | Training Slice 2/5
Epoch 2/5: 100%|██████████| 2/2 [00:34<00:00, 17.29s/it]
Shard 0, Slice 2, Loss: 1.70, Accuracy: 33.98%
Shard: 1 | Training Slice 3/5
Epoch 3/5: 100%|██████████| 3/3 [01:18<00:00, 26.14s/it]
Shard 0, Slice 3, Loss: 1.13, Accuracy: 54.60%
Shard: 1 | Training Slice 4/5
Epoch 4/5: 100%|██████████| 4/4 [02:22<00:00, 35.72s/it]
Shard 0, Slice 4, Loss: 0.53, Accuracy: 80.06%
Shard: 1 | Training Slice 5/5
Epoch 5/5: 100%|██████████| 5/5 [03:35<00:00, 43.09s/it]
Shard 0, Slice 5, Loss: 0.21, Accuracy: 93.29%

Training Shard 2/2
Shard: 2 | Training Slice 1/5
Epoch 1/5: 100%|██████████| 1/1 [00:10<00:00, 10.63s/it]
Shard 1, Slice 1, Loss: 2.12, Accuracy: 20.56%
Shard: 2 | Training Slice 2/5
Epoch 2/5: 100%|██████████| 2/2 [00:39<00:00, 19.69s/it]
Shard 1, Slice 2, Loss: 1.44, Accuracy: 42.44%
Shard: 2 | Training Slice 3/5
Epoch 3/5: 100%|██████████| 3/3 [01:22<00:00, 27.52s/it]
Shard 1, Slice 3, Loss: 0.86, Accuracy: 67.22%
Shard: 2 | Training Slice 4/5
Epoch 4/5: 100%|██████████| 4/4 [02:21<00:00, 35.43s/it]
Shard 1, Slice 4, Loss: 0.41, Accuracy: 86.18%
Shard: 2 | Training Slice 5/5
Epoch 5/5: 100%|██████████| 5/5 [03:38<00:00, 43.72s/it]
Shard 1, Slice 5, Loss: 0.20, Accuracy: 93.52%

'''

