In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import os
import torch.nn.functional as F

In [2]:
class ChestXRayDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.images, self.labels = self._load_data()

    def _load_data(self):
        images, labels = [], []
        class_folders = sorted(os.listdir(self.data_path))
        for label, class_folder in enumerate(class_folders):
            class_path = os.path.join(self.data_path, class_folder)
            for image_name in os.listdir(class_path):
                image_path = os.path.join(class_path, image_name)
                images.append(image_path)
                labels.append(label)
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = self.images[index]
        label = self.labels[index]

        image = Image.open(image_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

In [3]:
class ChestXRayModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        resnet = models.resnet18(pretrained=True)
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.fc_downstream = nn.Linear(resnet.fc.in_features, num_classes)

    def forward(self, x):
        features = self.resnet(x).squeeze()
        downstream_output = self.fc_downstream(features)
        return downstream_output

In [19]:
class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, output1, output2, labels):
        # Ensure that labels have the right shape for the cosine_similarity
        labels = labels.unsqueeze(1).expand(-1, 2).flatten()

        # Normalize the embeddings
        output1 = F.normalize(output1, p=2, dim=-1)
        output2 = F.normalize(output2, p=2, dim=-1)

        cosine_similarity = F.cosine_similarity(output1, output2, dim=-1)

        # Print sizes for debugging
        print("Output1 Size:", output1.size())
        print("Output2 Size:", output2.size())
        print("Labels Size:", labels.size())

        # Ensure the correct dimensions for labels
        labels = labels[:output1.size(0)]

        # Initialize loss to a default value
        loss = torch.tensor(0.0, requires_grad=True)

        # Calculate the contrastive loss
        loss = torch.mean((1 - labels) * torch.pow(cosine_similarity, 2) / self.temperature ** 2 +
                          labels * torch.exp(-cosine_similarity / self.temperature))

        return loss


In [22]:
data_path = "/content/drive/MyDrive/Covid-Image"
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

batch_size = 3
num_epochs_contrastive = 10
learning_rate_contrastive = 0.001

chest_xray_dataset = ChestXRayDataset(data_path, transform=transform)
chest_xray_dataloader = DataLoader(chest_xray_dataset, batch_size=batch_size, shuffle=True)
model = ChestXRayModel(num_classes=len(os.listdir(data_path)))
contrastive_loss = ContrastiveLoss()
optimizer_contrastive = optim.Adam(model.parameters(), lr=learning_rate_contrastive)

# Block 4: Contrastive Learning Phase
for epoch in range(num_epochs_contrastive):
    print(f"Contrastive Learning - Epoch {epoch + 1}/{num_epochs_contrastive}")
    for batch_images, batch_labels in chest_xray_dataloader:
        batch_images_augmented = torch.cat([batch_images, batch_images.flip(0)], dim=0)
        batch_labels_augmented = torch.cat([batch_labels, batch_labels], dim=0)

        optimizer_contrastive.zero_grad()
        contrastive_outputs = model(batch_images_augmented)

        # Print sizes for debugging
        print("Batch Images Size:", batch_images.size())
        print("Augmented Batch Images Size:", batch_images_augmented.size())
        print("Contrastive Outputs Size:", contrastive_outputs.size())

        # Check if there is a size mismatch
        if contrastive_outputs.size(0) != 2 * batch_size:
            print("Size Mismatch Error: Output size doesn't match expected size.")
            continue  # Skip this iteration to avoid the error

        loss = contrastive_loss(
            contrastive_outputs[:batch_size],
            contrastive_outputs[batch_size:],
            batch_labels_augmented.float()
        )

        # Aggregate the loss to a scalar value
        loss = torch.mean(loss)

        loss.backward()
        optimizer_contrastive.step()

        print(f"Batch loss: {loss.item()}")


Batch Images Size: torch.Size([3, 3, 224, 224])
Augmented Batch Images Size: torch.Size([6, 3, 224, 224])
Contrastive Outputs Size: torch.Size([6, 1])
Output1 Size: torch.Size([3, 1])
Output2 Size: torch.Size([3, 1])
Labels Size: torch.Size([12])
Batch loss: 100.0
Batch Images Size: torch.Size([3, 3, 224, 224])
Augmented Batch Images Size: torch.Size([6, 3, 224, 224])
Contrastive Outputs Size: torch.Size([6, 1])
Output1 Size: torch.Size([3, 1])
Output2 Size: torch.Size([3, 1])
Labels Size: torch.Size([12])
Batch loss: 100.0
Batch Images Size: torch.Size([3, 3, 224, 224])
Augmented Batch Images Size: torch.Size([6, 3, 224, 224])
Contrastive Outputs Size: torch.Size([6, 1])
Output1 Size: torch.Size([3, 1])
Output2 Size: torch.Size([3, 1])
Labels Size: torch.Size([12])
Batch loss: 100.0
Batch Images Size: torch.Size([3, 3, 224, 224])
Augmented Batch Images Size: torch.Size([6, 3, 224, 224])
Contrastive Outputs Size: torch.Size([6, 1])
Output1 Size: torch.Size([3, 1])
Output2 Size: torch.S

In [23]:
torch.save(model.resnet.state_dict(), "contrastive_backbone.pth")
print("Contrastive backbone (ResNet) saved successfully!")

Contrastive backbone (ResNet) saved successfully!
