In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet50
from torchvision.models import resnet18
from tqdm import tqdm

In [14]:
# Define the BYOL network architecture
class BYOL(nn.Module):
    def __init__(self, backbone):
        super(BYOL, self).__init__()
        self.backbone = backbone
        self.projector = nn.Sequential(
            nn.Linear(2048, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(),
            nn.Linear(4096, 256),
            nn.BatchNorm1d(256),
            nn.ReLU()
        )
        self.predictor = nn.Linear(256, 256)

    def forward(self, x1, x2):
        y1 = self.projector(self.backbone(x1))
        y2 = self.projector(self.backbone(x2))
        z1 = self.predictor(y1)
        z2 = self.predictor(y2)
        return y1, y2, z1, z2

In [18]:
# Define the training loop for BYOL
def train_byol(model, dataloader, optimizer, device):
    criterion = nn.CosineSimilarity(dim=1)
    loss_sum = 0.0
    model.train()
    for x1, x2 in tqdm(dataloader, total=len(dataloader)):
        x1 = x1.to(device)
        x2 = x2.to(device)
        y1, y2, z1, z2 = model(x1, x2)
        loss = 1 - criterion(y1, z2.detach()).mean() + 1 - criterion(y2, z1.detach()).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
    return loss_sum / len(dataloader)

In [4]:
# Define the training loop for supervised learning
def train_supervised(model, dataloader, criterion, optimizer, device):
    loss_sum = 0.0
    correct = 0
    total = 0
    model.train()
    for inputs, targets in tqdm(dataloader, total=len(dataloader)):
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    accuracy = 100.0 * correct / total
    return loss_sum / len(dataloader), accuracy

In [5]:
# Set random seed for reproducibility
torch.manual_seed(42)

# Set device (GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the dataset
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = CIFAR10(root='./data', train=False, transform=transform, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
# Split the dataset into labeled and unlabeled sets
labeled_indices = torch.randperm(len(train_dataset))[:1000]  # Adjust the number of labeled samples
unlabeled_indices = torch.randperm(len(train_dataset))[1000:]

# Create labeled and unlabeled dataloaders
labeled_dataloader = DataLoader(train_dataset, batch_size=128, sampler=labeled_indices, num_workers=4, pin_memory=True)
unlabeled_dataloader = DataLoader(train_dataset, batch_size=128, sampler=unlabeled_indices, num_workers=4, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)

# Create BYOL model
#backbone = resnet50(pretrained=True)
backbone = resnet18(pretrained=True)
backbone.fc = nn.Identity()  # Remove the final fully connected layer
model = BYOL(backbone).to(device)

# Create optimizer for BYOL
optimizer_byol = optim.SGD(model.parameters(), lr=0.03, momentum=0.9, weight_decay=1e-6)

In [19]:
# BYOL pre-training
num_epochs_byol = 100
for epoch in range(num_epochs_byol):
    loss = train_byol(model, unlabeled_dataloader, optimizer_byol, device)
    print(f"BYOL Pre-training - Epoch {epoch+1}/{num_epochs_byol}, Loss: {loss}")

  0%|                                                   | 0/383 [00:00<?, ?it/s]

torch.Size([128])


  0%|                                                   | 0/383 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x512 and 2048x4096)

In [None]:
# Create fully connected classification head
classifier = nn.Linear(256, 10).to(device)

# Create optimizer for supervised learning
optimizer_supervised = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-6)

In [None]:
# Supervised fine-tuning
num_epochs_supervised = 50
criterion = nn.CrossEntropyLoss()
for epoch in range(num_epochs_supervised):
    loss, accuracy = train_supervised(model.projector, labeled_dataloader, criterion, optimizer_supervised, device)
    print(f"Supervised Fine-tuning - Epoch {epoch+1}/{num_epochs_supervised}, Loss: {loss}, Accuracy: {accuracy}%")

# Evaluate the model on the test set
model.eval()
test_loss, test_accuracy = train_supervised(model.projector, test_dataloader, criterion, optimizer_supervised, device)
print(f"Test Loss: {test_loss}, Test Accuracy: {test_accuracy}%")

# Save the trained model
torch.save(model.state_dict(), "byol_model.pth")