In [1]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import STL10
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm.notebook import tqdm

transform = transforms.Compose([
    transforms.ToTensor()
])

unlabeled_dataset = STL10(root='../data', split='unlabeled', download=True, transform=transform)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=64, shuffle=True)

Files already downloaded and verified


In [2]:
from torchvision.models import resnet18

def get_resnet_backbone():
    model = resnet18(pretrained=True)
    in_features = model.fc.in_features
    model.fc = torch.nn.Identity()
    return model, in_features

In [3]:
import torch.nn as nn
import torch.nn.functional as F

class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim, use_bn=False):
        super(DINOHead, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, 2048),
            nn.GELU(),
            nn.Linear(2048, out_dim)
        )

    def forward(self, x):
        return self.mlp(x)

class DINOLoss(nn.Module):
    def __init__(self, out_dim):
        super(DINOLoss, self).__init__()
        self.register_buffer("center", torch.zeros(1, out_dim))
        self.teacher_temp = 0.04
        self.student_temp = 0.1
        self.center_momentum = 0.9

    def forward(self, student_output, teacher_output):
        teacher_output = F.softmax((teacher_output - self.center) / self.teacher_temp, dim=-1).detach()
        student_output = F.log_softmax(student_output / self.student_temp, dim=-1)

        loss = torch.mean(torch.sum(-teacher_output * student_output, dim=-1))

        self.update_center(teacher_output)
        return loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        batch_center = torch.mean(teacher_output, dim=0, keepdim=True)
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)

class DINO(nn.Module):
    def __init__(self, backbone, head):
        super(DINO, self).__init__()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        features = self.backbone(x)
        return self.head(features)

backbone, in_feat = get_resnet_backbone()
student_head = DINOHead(in_dim=in_feat, out_dim=65536)
teacher_head = DINOHead(in_dim=in_feat, out_dim=65536)

student = DINO(backbone, student_head)
teacher = DINO(backbone, teacher_head)

for param in teacher.parameters():
    param.requires_grad = False

criterion = DINOLoss(out_dim=65536)



In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student = student.to(device)
teacher = teacher.to(device)
criterion = criterion.to(device)
optimizer = optim.Adam(student.parameters(), lr=0.001)
num_epochs = 5

for epoch in range(num_epochs):
    for images, _ in tqdm(unlabeled_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False):
        images = images.to(device)

        student_output = student(images)
        with torch.no_grad():
            teacher_output = teacher(images)

        loss = criterion(student_output, teacher_output)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

Epoch [1/5]:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [1/5], Loss: 6.1964


Epoch [2/5]:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [2/5], Loss: 5.8126


Epoch [3/5]:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [3/5], Loss: 5.4330


Epoch [4/5]:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [4/5], Loss: 6.1183


Epoch [5/5]:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [5/5], Loss: 5.4700
