In [14]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch.optim as optim
import wandb

# Custom dataset class to load data from tensor files
class CustomDataset(Dataset):
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.file_list = [f for f in os.listdir(folder_path) if f.endswith('.pt')]
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_name = self.file_list[idx]
        file_path = os.path.join(self.folder_path, file_name)
        data = torch.load(file_path)
        image = data[0] 
        cell_count = torch.tensor(data[1]).float()

        # Apply transformation to resize the image to 120x120
        if self.transform:
            image = self.transform(image)

        return {'image': image, 'cell_count': cell_count}

# Define the U-Net model with a 3-layer MLP for regression
class SegmentationAndRegressionModel(nn.Module):
    def __init__(self, encoder_name="efficientnet-b0", encoder_weights="imagenet"):
        super(SegmentationAndRegressionModel, self).__init__()
        self.unet = smp.Unet(encoder_name, encoder_weights=encoder_weights, in_channels=1, classes=1)
        for param in self.unet.parameters(): #freeze params for speed
            param.requires_grad = False

        self.mlp = nn.Sequential(
            nn.Linear(16384, 512), 
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        seg_output = self.unet(x)
        mlp_input = seg_output.view(seg_output.size(0), -1)  # Flatten
        regression_output = self.mlp(mlp_input)
        return regression_output

# Function to evaluate the model on the validation set
def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in dataloader:
            inputs, labels = batch['image'].to(device), batch['cell_count'].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
    average_loss = total_loss / len(dataloader)
    return average_loss

# Function to train the model
def train_model(model, train_dataloader, test_dataloader, criterion, optimizer, device, num_epochs=150):
    model.train()
    model.to(device)

    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch in train_dataloader:
            inputs, labels = batch['image'].to(device), batch['cell_count'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        epoch_loss = running_loss / len(train_dataloader)

        # Log training loss to WandB
        wandb.log({"Training Loss": epoch_loss})

        # Evaluate the model on the validation set
        val_loss = evaluate_model(model, test_dataloader, criterion, device)

        # Log validation loss to WandB
        wandb.log({"Validation Loss": val_loss})

        print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, Validation Loss: {val_loss:.4f}')

if __name__ == "__main__":
    # Set your folder path containing the PyTorch tensor files
    train_data_folder = "../Data/input_tensors/train"
    test_data_folder = "../Data/input_tensors/val"

    transform = transforms.Compose([
        transforms.ToPILImage(),  
        transforms.Resize((128, 128)),  # Resize the image for consistent input shape
        transforms.ToTensor(), 
    ])

    # Create custom datasets for train and test
    train_dataset = CustomDataset(folder_path=train_data_folder, transform=transform)
    test_dataset = CustomDataset(folder_path=test_data_folder, transform=transform)

    # Create DataLoader for train and test datasets
    train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False)

    # Initialize WandB
    wandb.init(project="Count_cells", entity="turlagheoin")  # allow to show on wandb

    # Initialize the model
    model = SegmentationAndRegressionModel()

    # Check if GPU is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define loss function and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Train the model
    train_model(model, train_dataloader, test_dataloader, criterion, optimizer, device)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Training Loss,▁
Validation Loss,▁

0,1
Training Loss,55546.69125
Validation Loss,3292.66591


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888925108, max=1.0…

Epoch 1/150, Training Loss: 56644.2733, Validation Loss: 2536.9677
Epoch 2/150, Training Loss: 55171.0372, Validation Loss: 415.7535
Epoch 3/150, Training Loss: 54429.0211, Validation Loss: 408.9879
Epoch 4/150, Training Loss: 53843.8567, Validation Loss: 450.4972
Epoch 5/150, Training Loss: 52149.8096, Validation Loss: 2453.1691
Epoch 6/150, Training Loss: 50835.2037, Validation Loss: 599.6057
Epoch 7/150, Training Loss: 52760.2496, Validation Loss: 3156.8141
Epoch 8/150, Training Loss: 52045.5743, Validation Loss: 649.4500
Epoch 9/150, Training Loss: 51739.8051, Validation Loss: 531.7023
Epoch 10/150, Training Loss: 51520.6210, Validation Loss: 4608.3844
Epoch 11/150, Training Loss: 51556.1190, Validation Loss: 424.8141
Epoch 12/150, Training Loss: 51585.3427, Validation Loss: 292.5317
Epoch 13/150, Training Loss: 51417.7870, Validation Loss: 281.2149
Epoch 14/150, Training Loss: 51710.4152, Validation Loss: 286.1792
Epoch 15/150, Training Loss: 51608.1257, Validation Loss: 578.4853


KeyboardInterrupt: 

In [15]:
# Save the model for later use once training is done

torch.save(model.state_dict(), 'countingModel.pth')