In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, ConcatDataset
from torchvision.datasets import ImageFolder
from torch import nn, optim
from sklearn.metrics import accuracy_score

# Paths to the two image-folder datasets that are pooled for training.
data_dir_1 = './train'
data_dir_2 = './cell_images'

# Fraction of each dataset assigned to training; the remainder is validation.
TRAIN_FRACTION = 0.8
# Fixed seed so the train/val split is reproducible across notebook runs.
SPLIT_SEED = 42

# Normalization statistics applied to every image (ImageNet channel mean/std).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training-time augmentation and normalization.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation images get no random augmentation, so the validation loss used for
# early stopping is deterministic from epoch to epoch.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Alias kept so later cells that reference `transform` get the training pipeline.
transform = train_transform


def _split_with_transforms(train_view, root, train_size, seed=SPLIT_SEED):
    """Split `train_view` into disjoint (train, val) Subsets with per-split transforms.

    `random_split` cannot give its two halves different transforms because both
    Subsets wrap the same dataset object. Instead, a second ImageFolder over
    `root` using `val_transform` is built, and both views are indexed with one
    seeded permutation, so the halves are disjoint and only the training half
    is augmented.

    Args:
        train_view: ImageFolder over `root` using the training transform.
        root: Directory `train_view` was built from.
        train_size: Number of samples assigned to the training subset.
        seed: Seed for the permutation that assigns samples to each split.

    Returns:
        (train_subset, val_subset)

    Raises:
        RuntimeError: If the two scans of `root` disagree (the directory
            changed while loading), since the shared indices would then no
            longer refer to the same files.
    """
    val_view = ImageFolder(root=root, transform=val_transform)
    if val_view.samples != train_view.samples:
        raise RuntimeError(f"Contents of {root!r} changed while building the train/val split")
    generator = torch.Generator().manual_seed(seed)
    order = torch.randperm(len(train_view), generator=generator).tolist()
    return (
        torch.utils.data.Subset(train_view, order[:train_size]),
        torch.utils.data.Subset(val_view, order[train_size:]),
    )


# Load both datasets.
dataset_1 = ImageFolder(root=data_dir_1, transform=train_transform)
dataset_2 = ImageFolder(root=data_dir_2, transform=train_transform)

# ConcatDataset keeps each dataset's own integer labels, so both folders must map
# class names to the same indices or the pooled labels would be inconsistent.
if dataset_1.class_to_idx != dataset_2.class_to_idx:
    raise ValueError(
        f"Class folders differ between {data_dir_1!r} ({dataset_1.class_to_idx}) and "
        f"{data_dir_2!r} ({dataset_2.class_to_idx}); labels would not line up."
    )

# Split each dataset into train/val.
train_size = int(TRAIN_FRACTION * len(dataset_1))
val_size = len(dataset_1) - train_size
train_dataset_1, val_dataset_1 = _split_with_transforms(dataset_1, data_dir_1, train_size)

train_size_2 = int(TRAIN_FRACTION * len(dataset_2))
val_size_2 = len(dataset_2) - train_size_2
train_dataset_2, val_dataset_2 = _split_with_transforms(dataset_2, data_dir_2, train_size_2)

# Combine datasets
combined_train_dataset = ConcatDataset([train_dataset_1, train_dataset_2])
combined_val_dataset = ConcatDataset([val_dataset_1, val_dataset_2])

# DataLoaders
train_loader = DataLoader(combined_train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(combined_val_dataset, batch_size=32, shuffle=False)

# Define the model architectures
def initialize_model(model_name, num_classes=1, pretrained=True):
    """Build a torchvision backbone whose final layer outputs `num_classes` values.

    Args:
        model_name: One of "resnet50", "resnet18", "vgg16", "densenet121",
            "efficientnet_b0".
        num_classes: Width of the new output layer. The default of 1 yields a
            single logit per image.
        pretrained: Forwarded as ``pretrained=`` to the torchvision model
            constructor. (torchvision >= 0.13 deprecates this argument in
            favour of ``weights=``.)

    Returns:
        The model with its classification head replaced.

    Raises:
        ValueError: If `model_name` is not one of the supported names.
    """
    supported = ("resnet50", "resnet18", "vgg16", "densenet121", "efficientnet_b0")
    # Validate first so a typo fails before any model is constructed.
    if model_name not in supported:
        raise ValueError(
            f"Model name not recognized: {model_name!r} "
            f"(expected one of: {', '.join(supported)})"
        )

    if model_name == "resnet50":
        model = torchvision.models.resnet50(pretrained=pretrained)
        # NOTE(review): p=0.8 dropout before every layer of this MLP head is very
        # aggressive and unlike the single-Linear heads of the other models;
        # kept unchanged, but worth revisiting if resnet50 underfits.
        model.fc = nn.Sequential(
            nn.Dropout(0.8),
            nn.Linear(model.fc.in_features, 1024),
            nn.ReLU(),
            nn.Dropout(0.8),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.8),
            nn.Linear(512, num_classes),
        )
    elif model_name == "resnet18":
        model = torchvision.models.resnet18(pretrained=pretrained)
        # Wrapped in Sequential to keep the original state_dict key layout (fc.0.*).
        model.fc = nn.Sequential(
            nn.Linear(model.fc.in_features, num_classes)
        )
    elif model_name == "vgg16":
        model = torchvision.models.vgg16(pretrained=pretrained)
        # Read the input width from the layer being replaced instead of hard-coding it.
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    elif model_name == "densenet121":
        model = torchvision.models.densenet121(pretrained=pretrained)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    else:  # "efficientnet_b0" -- the only remaining supported name
        model = torchvision.models.efficientnet_b0(pretrained=pretrained)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

    return model

# Training and evaluation function
def _evaluate(model, loader, criterion, device):
    """Return (per-sample mean loss, accuracy) of a single-logit model over `loader`, without gradients."""
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs).reshape(-1)
            loss_sum += criterion(logits, labels.float()).item() * labels.size(0)
            # logit > 0  <=>  sigmoid(logit) > 0.5
            correct += ((logits > 0).long() == labels).sum().item()
            seen += labels.size(0)
    if seen == 0:
        raise ValueError("val_loader yielded no samples")
    return loss_sum / seen, correct / seen


def train_and_evaluate(model, train_loader, val_loader, optimizer, num_epochs=10, patience=3):
    """Train a single-logit binary classifier with early stopping on validation loss.

    Uses BCEWithLogitsLoss, so the model must emit one logit per sample
    (shape (B,) or (B, 1)), and labels must be 0/1.

    Args:
        model: Module to train; it is moved to CUDA when available.
        train_loader: Yields (inputs, labels) batches for training.
        val_loader: Yields (inputs, labels) batches for validation.
        optimizer: Optimizer already bound to `model`'s parameters.
        num_epochs: Maximum number of epochs.
        patience: Stop after this many consecutive epochs without a
            validation-loss improvement.

    Returns:
        The lowest validation loss observed. On return, `model` holds the
        weights from the epoch that achieved it.

    Raises:
        ValueError: If either loader yields no samples.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    best_val_loss = float('inf')
    best_state = None
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_preds = 0
        total_preds = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            # reshape(-1) rather than squeeze(): squeeze() turns a size-1 final
            # batch into a 0-d tensor that no longer matches `labels`' shape.
            outputs = model(inputs).reshape(-1)
            loss = criterion(outputs, labels.float())
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            # logit > 0  <=>  sigmoid(logit) > 0.5
            preds = (outputs > 0).long()
            correct_preds += (preds == labels).sum().item()
            total_preds += labels.size(0)

        if total_preds == 0:
            raise ValueError("train_loader yielded no samples")
        # Normalise by samples actually seen (correct even with drop_last=True).
        epoch_loss = running_loss / total_preds
        epoch_acc = correct_preds / total_preds

        val_loss, val_acc = _evaluate(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f} - Accuracy: {epoch_acc:.4f}")
        print(f"Validation Loss: {val_loss:.4f} - Validation Accuracy: {val_acc:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Snapshot (copy, not reference) so later updates don't overwrite it.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered")
                break

    # Leave the model at its best epoch rather than the last one trained.
    if best_state is not None:
        model.load_state_dict(best_state)
    return best_val_loss

# Train and evaluate all models
# Train each candidate backbone on the same loaders and record its best
# validation loss so the models can be compared afterwards.
model_names = ["resnet50", "resnet18", "vgg16", "densenet121", "efficientnet_b0"]
results = {}
for model_name in model_names:
    print(f"\nTraining {model_name.upper()}...")
    model = initialize_model(model_name)
    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=5e-4)
    results[model_name] = train_and_evaluate(model, train_loader, val_loader, optimizer)

# Summary, best (lowest validation loss) first.
print("\nBest validation loss per model:")
for model_name, best_loss in sorted(results.items(), key=lambda item: item[1]):
    print(f"  {model_name:<16} {best_loss:.4f}")



Training RESNET50...


KeyboardInterrupt: 