<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Self_Supervised_Learning_(SSL).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install --upgrade torch torchvision

In [None]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

# Custom dataset for rotation pretext task
class RotationDataset(Dataset):
    """Wrap an image dataset for the rotation-prediction pretext task.

    Every image of the wrapped dataset appears once per rotation angle, so this
    wrapper is ``len(rotation_angles)`` times longer than ``dataset``. Index
    ``idx`` maps to image ``idx // len(rotation_angles)`` rotated by
    ``rotation_angles[idx % len(rotation_angles)]`` degrees. The returned target
    is the position of that angle in ``rotation_angles``; the wrapped dataset's
    own label is discarded.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs; the
            images are passed to ``transforms.functional.rotate``.
        rotation_angles: Rotation angles in degrees. Defaults to
            ``(0, 90, 180, 270)``. The model's number of output classes must
            match the number of angles.

    Raises:
        ValueError: If ``rotation_angles`` is empty.
    """

    def __init__(self, dataset, rotation_angles=(0, 90, 180, 270)):
        self.dataset = dataset
        # Copy into a list: keeps the attribute type of the original code and
        # decouples it from the caller's sequence.
        self.rotation_angles = list(rotation_angles)
        if not self.rotation_angles:
            raise ValueError('rotation_angles must contain at least one angle')

    def __len__(self):
        return len(self.dataset) * len(self.rotation_angles)

    def __getitem__(self, idx):
        num_angles = len(self.rotation_angles)
        # Consecutive indices cycle through all angles of one image before
        # moving on to the next image.
        image_idx = idx // num_angles
        rotation_idx = idx % num_angles
        angle = self.rotation_angles[rotation_idx]

        # Fetch the source image (its class label is not used) and rotate it.
        image, _ = self.dataset[image_idx]
        rotated_image = transforms.functional.rotate(image, angle)

        # The pretext label is the index of the applied rotation.
        return rotated_image, rotation_idx

# Load the CIFAR-10 training split as tensors, then wrap it so each image is
# served once per rotation angle, and batch it with shuffling for training.
original_dataset = CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
rotation_dataset = RotationDataset(original_dataset)
dataloader = DataLoader(
    rotation_dataset,
    batch_size=32,
    shuffle=True,
)

# Define a pretext task model for rotation prediction
class PretextTask(nn.Module):
    """Rotation-classification head on top of a backbone exposing ``fc``.

    The backbone's ``fc`` layer is replaced by ``nn.Identity`` so the backbone
    returns the features that ``fc`` used to consume, and a new linear layer
    maps those features to ``num_classes`` rotation logits.

    Args:
        base_model: Backbone module with an ``fc`` attribute that has
            ``in_features`` (e.g. a torchvision ResNet). It is modified in place.
        num_classes: Number of rotation classes to predict. Defaults to 4,
            one per angle in the default rotation set (0, 90, 180, 270).
    """

    def __init__(self, base_model, num_classes=4):
        super().__init__()
        self.base_model = base_model
        # Read the feature width from the original head before discarding it.
        self.fc = nn.Linear(base_model.fc.in_features, num_classes)
        # Strip the original classification head so base_model emits features.
        self.base_model.fc = nn.Identity()

    def forward(self, x):
        """Return rotation logits for ``x``.

        Assumes the backbone returns a (batch, features) tensor — TODO confirm
        for non-ResNet backbones.
        """
        features = self.base_model(x)
        return self.fc(features)

# Build a randomly initialized ResNet-18 backbone. `weights=None` is the
# current torchvision spelling of the deprecated `pretrained=False`.
base_model = models.resnet18(weights=None)
model = PretextTask(base_model)

# Train on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the model before creating the optimizer so it is set up on the
# device-resident parameters.
model.to(device)

# Rotation prediction is treated as plain multi-class classification.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_samples = 0
    for images, labels in dataloader:
        # Inputs and targets must live on the same device as the model.
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion's default reduction is a per-batch mean; weight it by
        # the batch size so a smaller final batch does not skew the average.
        batch_size = labels.size(0)
        epoch_loss += loss.item() * batch_size
        num_samples += batch_size
    # max(..., 1) avoids a ZeroDivisionError if the loader yields nothing.
    print(f'Epoch {epoch + 1}, Loss: {epoch_loss / max(num_samples, 1):.4f}')