In [1]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

class SegmentationDataset(Dataset):
    def __init__(self, img_dir, mask_dir, ignore_dir, feature_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.ignore_dir = ignore_dir
        self.feature_dir = feature_dir
        self.transform = transform

        self.img_names = sorted(os.listdir(img_dir))
        self.mask_names = sorted(os.listdir(mask_dir))
        self.ignore_names = sorted(os.listdir(ignore_dir))
        self.feature_names = sorted(os.listdir(feature_dir))

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        print(idx)
        feature_path = os.path.join(self.feature_dir, self.feature_names[idx])
        features = torch.load(feature_path).transpose(2,0,1)

        img_path = os.path.join(self.img_dir, self.img_names[idx])
        img = Image.open(img_path).convert('RGB')
        img = np.array(img.resize((img.width//14, img.height//14)))
        
        mask_path = os.path.join(self.mask_dir, self.mask_names[idx])
        mask = np.array(Image.open(mask_path).convert("L"))/255.0
        mask=mask[np.newaxis, :, :]
        mask = torch.tensor(mask, dtype=torch.float).cuda()
        
        ignore_path = os.path.join(self.ignore_dir, self.ignore_names[idx])
        with open(ignore_path, 'rb') as f:
            ignore_coords = pickle.load(f)


        return {'image': img, 'features': features, 'mask': mask, 'ignore_coords': ignore_coords}

In [2]:
# Training loop
def train_model(model, dataloader, criterion, optimizer, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch in dataloader:
            features = batch['features']
            masks = batch['mask']

            for i in range(features.shape[2]):
                for j in range(features.shape[3]):
                    features = features[:, :, i, j]
                    masks = masks[:, :, i, j]
                    
                    print(features.shape)   #torch.Size([1, 1536, 215, 176])  CHANGING SHAPES RN want, (1536)
                    print(masks.shape)#torch.Size([1, 1, 214, 176]) want, (2996, 2464)

                        #will slice th emap into 14 x14 pieces and then flatten them, that is ground truth output for that section
                    # Flatten the features to match the input shape of the model
                    features = features.view(features.size(0), -1)

                    # Forward pass
                    outputs = model(features)
                    loss = criterion(outputs, masks)

                    # Backward pass and optimization
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    running_loss += loss.item()
        
        epoch_loss = running_loss / len(dataloader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')
    print("Training complete.")

    def calculate_loss(self, image, mask, ignore_coords):
        logits_mask = self.forward(image)

        # Create a mask of the same size as the image, with 1s at the coordinates to ignore and 0s elsewhere
        ignore_mask = torch.zeros_like(mask).cuda()
        for y, x in ignore_coords:
            y=y.item()
            x=x.item()  
            ignore_mask[:, :, x, y] = 1

        logits_mask = logits_mask * (1 - ignore_mask)
        mask = mask * (1 - ignore_mask)

        # Calculate the loss
        loss = self.loss_fn(logits_mask, mask)

        return loss, logits_mask

In [3]:

# Define the neural network
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(1536, 196)  # Fully connected layer to map input to 196 units
        self.relu = nn.ReLU()  # Activation function
        self.reshape = lambda x: x.view(-1, 14, 14)  # Reshape output to (14, 14)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.reshape(x)
        return x

# Example usage
if __name__ == "__main__":
    img_dir = '/pasteur/u/aunell/cryoViT/data/sample_data/original'
    mask_dir = '/pasteur/u/aunell/cryoViT/data/training/mask'
    ignore_dir = '/pasteur/u/aunell/cryoViT/data/training/ignore'
    feature_dir = '/pasteur/u/aunell/cryoViT/data/training/features'

    dataset = SegmentationDataset(img_dir, mask_dir, ignore_dir, feature_dir)
    print('len dataset:', len(dataset))
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

    model = MyModel()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    train_model(model, dataloader, criterion, optimizer, num_epochs=10)

len dataset: 7


2
1
06

