In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import transforms, models
from PIL import Image
import numpy as np
import os
import random
from tqdm import tqdm
from datasets import load_dataset

ModuleNotFoundError: No module named 'torch'

In [3]:
!pip install datasets



In [4]:
# Seed every RNG used in this notebook (Python, NumPy, PyTorch) from a single constant.
# TODO: maybe use PyTorch Lightning's seeding helper for stricter reproducibility.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Backbone selector, 'resnet' or 'vit'; read by the dataset, model and training cells below.
architecture = 'vit'

<torch._C.Generator at 0x796de6f0d0d0>

In [5]:
# Load the Tiny ImageNet dataset ("zh-plus/tiny-imagenet") through Hugging Face `datasets`.
tinyImageNet_dataset = load_dataset("zh-plus/tiny-imagenet")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [6]:
# Display the available splits, their features and their sizes.
tinyImageNet_dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 100000
    })
    valid: Dataset({
        features: ['image', 'label'],
        num_rows: 10000
    })
})

In [8]:
def rotate_img(img, rot):
    """Rotate an image array by a multiple of 90 degrees.

    Args:
        img: Image array; the 90/270 cases swap axes 0 and 1 of a 3-axis
            (rows, columns, channels) array.
        rot: Rotation angle in degrees; one of 0, 90, 180 or 270.

    Returns:
        `img` itself when `rot` is 0, otherwise the array obtained from the
        corresponding numpy flips/transposes.

    Raises:
        ValueError: If `rot` is not one of the supported angles.
    """
    if rot not in (0, 90, 180, 270):
        raise ValueError('Rotation should be 0, 90, 180, or 270 degrees.')

    if rot == 0:
        # Nothing to do: hand back the input untouched.
        return img

    if rot == 180:
        # Reverse the rows first, then the columns.
        upside_down = np.flipud(img)
        return np.fliplr(upside_down)

    # Quarter turns swap the two spatial axes and keep the channel axis last.
    swap_spatial = (1, 0, 2)
    if rot == 90:
        return np.flipud(np.transpose(img, swap_spatial))
    return np.transpose(np.flipud(img), swap_spatial)

class RotationDataset(data.Dataset):
    """Rotation-prediction dataset built on top of a HuggingFace image dataset.

    Every item is one source image shown at 0, 90, 180 and 270 degrees, together
    with the class index (0, 1, 2, 3) of each rotation.
    """

    def __init__(self, hf_dataset, transform=None, architecture='resnet'):
        """
        Input:
            hf_dataset: HuggingFace Dataset object whose items hold a PIL image under 'image'.
            transform: Optional callable applied to each rotated PIL image. It must
                return a tensor so the four views can be stacked. When omitted,
                a plain `transforms.ToTensor()` is used.
            architecture: 'resnet' resizes images to 255x255; any other value resizes to 224x224.
        """
        self.dataset = hf_dataset
        # __getitem__ always calls the transform, so a missing one used to fail with
        # "'NoneType' object is not callable". Fall back to a bare tensor conversion.
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.rotations = [0, 90, 180, 270]
        self.architecture = architecture

    def __len__(self):
        """Number of source images (each one yields four rotated views)."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return `(views, labels)` for the image at `idx`.

        Returns:
            views: Tensor stacking the transformed rotations along dim 0
                (shape [4, 3, H, W] when the transform yields [3, H, W] tensors).
            labels: LongTensor `[0, 1, 2, 3]`, the rotation class of each view.
        """
        # Load image from the HuggingFace dataset and make sure it is RGB.
        image = self.dataset[idx]['image'].convert('RGB')

        # Square input size used by this notebook for the selected backbone.
        side = 255 if self.architecture == 'resnet' else 224
        # Convert to an array once; rotate_img only builds flipped/transposed versions of it.
        pixels = np.array(image.resize((side, side)))

        # Rotate -> back to PIL -> transform, in the order given by self.rotations.
        rotated_imgs = [
            self.transform(Image.fromarray(rotate_img(pixels, rot)))
            for rot in self.rotations
        ]
        # Class index i corresponds to self.rotations[i].
        rotation_labels = torch.arange(len(self.rotations), dtype=torch.long)

        return torch.stack(rotated_imgs, dim=0), rotation_labels


In [9]:
# Preprocessing applied to every rotated image: light colour jitter, then PIL -> tensor.
transform = transforms.Compose(
    [
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.ToTensor(),
    ]
)

In [10]:
class RotationNet(nn.Module):
    """Backbone plus a small MLP head that predicts which rotation was applied to an image."""

    def __init__(self,
                 n_rotations=4,  # 4 rotations: 0°, 90°, 180°, 270°
                 architecture = 'resnet', # 'resnet' or 'vit'
                ):
        """
        Args:
            n_rotations: Number of rotation classes the head outputs.
            architecture: 'resnet' (ResNet-18 backbone, 512-d features) or
                'vit' (ViT-B/16 backbone, 768-d features).

        Raises:
            ValueError: If `architecture` is neither 'resnet' nor 'vit'.
        """
        super().__init__()

        if architecture == 'resnet':
            # Randomly initialised ResNet-18. TODO: replace by ResNet 50
            self.backbone = models.resnet18()
            self.backbone.fc = nn.Identity()  # Remove the classification layer
            feature_dim = 512
        elif architecture == 'vit':
            # Randomly initialised ViT-B/16, built without the deprecated
            # `pretrained=` argument, the same way as the ResNet branch.
            self.backbone = models.vit_b_16()
            self.backbone.heads = nn.Identity()  # Remove the classification head
            feature_dim = 768  # Feature dimension for ViT-B_16
        else:
            # Fail here with a clear message instead of an UnboundLocalError on feature_dim below.
            raise ValueError(
                f"Unsupported architecture {architecture!r}; expected 'resnet' or 'vit'."
            )

        # Rotation-classification head; per the original note it is meant to be
        # discarded after pretext training. TODO: hidden size of 128 is unverified.
        self.fc = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, n_rotations)
        )

    def forward(self, x):
        """Map a batch of images to rotation logits.

        Args:
            x: Image batch of shape [batch_size, 3, H, W]. In this notebook the
                images are 255x255 for 'resnet' and 224x224 for 'vit'.

        Returns:
            Logits of shape [batch_size, n_rotations].
        """
        features = self.backbone(x)  # Shape: [batch_size, feature_dim]
        return self.fc(features)

In [11]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the rotation classifier for the configured backbone and move it to that device.
model = RotationNet(n_rotations=4, architecture=architecture).to(device)


In [12]:

# Create the datasets and dataloaders.
# Training uses the augmenting `transform`. Validation only converts images to tensors,
# so validation loss/accuracy are not perturbed by random colour jitter.
valid_transform = transforms.ToTensor()
train_dataset = RotationDataset(tinyImageNet_dataset['train'], transform=transform, architecture=architecture)
valid_dataset = RotationDataset(tinyImageNet_dataset['valid'], transform=valid_transform, architecture=architecture)

# Each dataset item expands to 4 rotated views, so the model sees 4 * batch_size images per step.
batch_size = 128 if architecture == 'resnet' else 64

train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
valid_loader = data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4)



In [13]:
# Sanity check: pull one training batch and report the shape of its image tensor.
print(f'Shape of dataset output: {next(iter(train_loader))[0].shape}')


Shape of dataset output: torch.Size([128, 4, 3, 255, 255])


In [14]:
# Training schedule.
num_epochs = 10
log_each = 100  # report the running training loss every `log_each` batches

# Cross-entropy over the rotation classes, optimised with SGD (momentum + weight decay).
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
    weight_decay=1e-4,
)

In [None]:
# Side length of the square images produced by RotationDataset for the selected backbone.
hw = 255 if architecture == 'resnet' else 224
for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    avg_loss = 0.0  # running sum of batch losses since the last log line
    for batch_idx, (rotated_imgs, rotation_labels) in enumerate(tqdm(train_loader)):
        # Fold the rotation axis into the batch axis: [B, 4, 3, hw, hw] -> [B * 4, 3, hw, hw].
        rotated_imgs = rotated_imgs.view(-1, 3, hw, hw).to(device)
        rotation_labels = rotation_labels.view(-1).to(device)  # Shape: [B * 4]

        optimizer.zero_grad()
        outputs = model(rotated_imgs)  # Shape: [B * 4, n_rotations]
        loss = criterion(outputs, rotation_labels)  # Scalar (default mean reduction)
        loss.backward()
        optimizer.step()

        avg_loss += loss.item()
        if batch_idx % log_each == log_each - 1:
            # Average over the real logging window; this was a hard-coded 100,
            # which is only correct while log_each == 100.
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}], Loss: {avg_loss / log_each:.4f}')
            avg_loss = 0.0

    # ---- Validation pass ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for rotated_imgs, rotation_labels in tqdm(valid_loader):
            rotated_imgs = rotated_imgs.view(-1, 3, hw, hw).to(device)  # Shape: [B * 4, 3, hw, hw]
            rotation_labels = rotation_labels.view(-1).to(device)  # Shape: [B * 4]

            outputs = model(rotated_imgs)  # Shape: [B * 4, n_rotations]
            loss = criterion(outputs, rotation_labels)
            val_loss += loss.item()

            # Predicted rotation class = index of the largest logit.
            predicted = outputs.argmax(dim=1)
            total += rotation_labels.size(0)
            correct += (predicted == rotation_labels).sum().item()

    val_accuracy = 100 * correct / total
    # Mean of the per-batch mean losses.
    avg_val_loss = val_loss / len(valid_loader)
    print(f'Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')


  0%|          | 0/782 [00:00<?, ?it/s]