In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import random
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

In [2]:
# Baseline preprocessing: convert the image to a tensor, then shift/scale each
# of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [3]:
# CIFAR-10 training split with the baseline preprocessing, served in shuffled
# mini-batches of 4 by two worker processes.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=4, shuffle=True, num_workers=2)

# CIFAR-10 test split, same preprocessing, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=4, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Human-readable names of the ten CIFAR-10 categories.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [5]:
class SimCLRDataTransform(object):
    """Random augmentation pipeline used to create one SimCLR view of an image.

    The pipeline applies, in order: random resized crop, random horizontal
    flip, colour jitter (applied with probability 0.8), random grayscale
    (probability 0.2), Gaussian blur, tensor conversion and per-channel
    normalisation with mean 0.5 / std 0.5.

    Args:
        size: Integer side length of the output crop. It also determines the
            Gaussian-blur kernel size (about 10% of ``size``).
    """

    def __init__(self, size):
        blur_kernel = self._blur_kernel_size(size)
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size),  # random crop, resized to `size`
            transforms.RandomHorizontalFlip(),  # random left-right flip
            transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),  # colour distortion
            transforms.RandomGrayscale(p=0.2),  # occasional grayscale conversion
            transforms.GaussianBlur(kernel_size=blur_kernel),  # Gaussian blur
            transforms.ToTensor(),  # convert the image to a tensor
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # per-channel normalisation
        ])

    @staticmethod
    def _blur_kernel_size(size):
        """Return an odd, positive blur kernel size close to 10% of ``size``.

        ``GaussianBlur`` rejects even or non-positive kernel sizes, so the raw
        ``int(0.1 * size)`` value is clamped to at least 1 and bumped to the
        next odd number when it is even.
        """
        kernel = max(int(0.1 * size), 1)
        if kernel % 2 == 0:
            kernel += 1
        return kernel

    def __call__(self, x):
        """Apply the augmentation pipeline to ``x`` and return the result."""
        return self.transform(x)

In [6]:
from torch.utils.data import Dataset

class SimCLRDataset(Dataset):
    """Wrap a dataset of ``(image, label)`` items so each item becomes a pair of views.

    The label is discarded. When a transform is supplied it is called twice on
    the same image (once per view); otherwise the image itself is returned
    twice.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        """Number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the two views of the image stored at ``idx``."""
        image, _ = self.dataset[idx]

        if not self.transform:
            return image, image

        # Two separate calls so a random transform yields two different views.
        first_view = self.transform(image)
        second_view = self.transform(image)
        return first_view, second_view

# Use custom dataset
# CIFAR-10 is loaded without a transform so that SimCLRDataset receives the
# raw images and can augment each one twice. The random SimCLR augmentation
# (not the deterministic baseline `transform`) must be used here; otherwise
# both views of an image are identical and the contrastive task is trivial.
# CIFAR-10 images are 32x32, hence size=32.
trainset = SimCLRDataset(
    torchvision.datasets.CIFAR10(root='./data', train=True, download=True),
    transform=SimCLRDataTransform(size=32),
)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

Files already downloaded and verified


In [7]:
class ResNetEncoder(nn.Module):
    """ResNet-50 backbone with its last child module removed.

    Args:
        pretrained: Forwarded to ``models.resnet50``.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        backbone = models.resnet50(pretrained=pretrained)
        # Keep every child module except the final one (the fully connected layer).
        layers = list(backbone.children())
        self.features = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Run the backbone and flatten everything after the batch dimension."""
        return torch.flatten(self.features(x), start_dim=1)

In [8]:
class ProjectionHead(nn.Module):
    """Two-layer MLP: linear -> ReLU -> linear.

    Args:
        input_dim: Size of the incoming feature vectors.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the projected vectors.
    """

    def __init__(self, input_dim=2048, hidden_dim=512, output_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Project ``x`` through the hidden layer into the output space."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

In [9]:
class SimCLR(nn.Module):
    """Compose an encoder and a projection head into a single module.

    Args:
        encoder: Module applied to the input first.
        projection_head: Module applied to the encoder's output.
    """

    def __init__(self, encoder, projection_head):
        super().__init__()
        self.encoder = encoder
        self.projection_head = projection_head

    def forward(self, x):
        """Return ``projection_head(encoder(x))``."""
        return self.projection_head(self.encoder(x))

In [10]:
# Assemble the SimCLR network from its two parts; the argument values below
# are the defaults of the respective constructors, spelled out for readability.
encoder = ResNetEncoder(pretrained=True)
projection_head = ProjectionHead(input_dim=2048, hidden_dim=512, output_dim=128)
model = SimCLR(encoder=encoder, projection_head=projection_head)

In [11]:
# def nt_xent_loss(features, temperature=0.5):
#     """
#     Compute the NT-Xent loss.

#     :param features: feature vectors of shape [2 * batch_size, feature_dim]
#     :param temperature: temperature parameter
#     :return: the loss value
#     """
#     device = features.device
#     batch_size, _ = features.shape
#     labels = torch.cat([torch.arange(batch_size // 2) for _ in range(2)], dim=0)
#     labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float().to(device)

#     features = F.normalize(features, dim=1)

#     similarity_matrix = torch.matmul(features, features.T)

#     # Exclude the diagonal elements
#     mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
#     labels = labels[~mask].view(labels.shape[0], -1)
#     similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

#     # Compute the loss with cross-entropy
#     positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
#     negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

#     logits = torch.cat([positives, negatives], dim=1)
#     labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)

#     logits = logits / temperature
#     return F.cross_entropy(logits, labels)




def NT_XEnt(interleaved_inputs, temperature=0.5):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss.

    The rows of ``interleaved_inputs`` must be interleaved pairs: rows ``2k``
    and ``2k + 1`` are the two views of the same sample and act as each
    other's positive; every other row is a negative.

    Args:
        interleaved_inputs: 2-D tensor of shape ``[2 * n_pairs, feature_dim]``.
        temperature: Divisor applied to the cosine similarities before the
            softmax.

    Returns:
        Scalar tensor: the cross-entropy averaged over all rows.

    Raises:
        ValueError: If the number of rows is odd (some row would have no
            partner).
    """
    length = interleaved_inputs.shape[0]
    if length % 2 != 0:
        raise ValueError(
            f"NT_XEnt expects an even number of rows (interleaved pairs), got {length}"
        )
    device = interleaved_inputs.device

    # Pairwise cosine similarity as a matrix product of unit-length rows; this
    # avoids materialising a (length, length, feature_dim) broadcast tensor.
    normalized = F.normalize(interleaved_inputs, dim=-1)
    similarity = normalized @ normalized.T

    # A row must never select itself, so the diagonal is set to -inf before
    # the softmax. The mask lives on the same device as the inputs.
    self_mask = torch.eye(length, dtype=torch.bool, device=device)
    logits = similarity.masked_fill(self_mask, float("-inf")) / temperature

    # Positive index of each row is its partner in the pair: 0<->1, 2<->3, ...
    target = torch.arange(length, device=device)
    target[0::2] += 1
    target[1::2] -= 1

    return F.cross_entropy(logits, target, reduction="mean")


In [None]:
import torch.optim as optim


batch_size = 256  # Adjust according to your GPU
# NOTE(review): `trainloader` is built in an earlier cell with its own batch
# size, so the value above is not applied to it -- confirm which is intended.
epochs = 100  # Adjust to your needs
# Define learning rate and momentum
learning_rate = 0.1  # According to the paper, the initial learning rate is set to 0.1, but you can adjust it according to the actual situation
momentum = 0.9       # Momentum is usually set to 0.9

# Instantiate the SGD optimizer
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

print("optimizer")
# trainloader yields one pair of augmented image batches per step
for epoch in range(epochs):
    model.train()  # Set the model to training mode
    total_loss = 0
    for batch in trainloader:
        # Each batch is a pair of view tensors, one entry per image in each
        images1, images2 = batch

        # NT_XEnt treats rows 2k and 2k+1 as a positive pair, so the two views
        # of every image must sit next to each other. Stacking on a new dim 1
        # and flattening dims 0-1 gives [img0_v1, img0_v2, img1_v1, img1_v2, ...];
        # a plain torch.cat along dim 0 would instead pair different images.
        interleaved_images = torch.stack((images1, images2), dim=1).flatten(0, 1)

        optimizer.zero_grad()  # Reset the gradients

        features = model(interleaved_images)  # Projected feature vectors
        loss = NT_XEnt(features)  # Compute the NT-Xent loss

        loss.backward()  # Backpropagate
        optimizer.step()  # Update the weights

        total_loss += loss.item()

    avg_loss = total_loss / len(trainloader)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

torch.save(model.state_dict(), 'simclr_model.pth')
print("Done")

optimizer
