<div align="left">

<table>
  <thead>
    <tr>
      <th>Part</th>
      <th>Shape</th>
      <th>Explanation</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Input to model</td>
      <td>(6, 128, 128)</td>
      <td>Stack of two (3-channel) segmented images</td>
    </tr>
    <tr>
      <td>Ground truth</td>
      <td>(3, 128, 128)</td>
      <td>Background, liver, tumor mask</td>
    </tr>
    <tr>
      <td>Model output</td>
      <td>(3, 128, 128)</td>
      <td>Predicted background, liver, tumor</td>
    </tr>
  </tbody>
</table>

</div>


In [2]:
import os
from torch.utils.data import Dataset
from PIL import Image
import torch
# import torchvision.transforms as T
class MultiInputSegmentationDataset(Dataset):
    """Paired-image segmentation dataset that stacks two inputs channel-wise.

    A sample is identified by a file name found in ``folder1``. The same name
    is looked up in ``folder2`` (second input) and in ``gt_folder`` (mask), so
    all three folders are expected to contain identically named files.

    Args:
        folder1: Directory with the first input image of every sample. Its
            file listing defines which samples exist.
        folder2: Directory with the second input image, matched by file name.
        gt_folder: Directory with the ground-truth mask, matched by file name.
        transform: Optional callable applied separately to both inputs and to
            the mask. It must return tensors, because the two inputs are
            joined with ``torch.cat``.
    """

    def __init__(self, folder1, folder2, gt_folder, transform=None):
        self.folder1 = folder1
        self.folder2 = folder2
        self.gt_folder = gt_folder
        self.transform = transform
        # os.listdir returns names in arbitrary order and also lists
        # sub-directories (e.g. ``.ipynb_checkpoints``). Keep regular files
        # only and sort them so that index -> file is stable across runs.
        self.image_names = sorted(
            name
            for name in os.listdir(folder1)
            if os.path.isfile(os.path.join(folder1, name))
        )

    def __len__(self):
        """Return the number of samples (files found in ``folder1``)."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(input_image, gt)``.

        Returns:
            input_image: The two transformed inputs concatenated along dim 0
                (presumably (6, 128, 128) when both files have 3 channels and
                the transform resizes to 128x128 -- depends on the data).
            gt: The transformed ground-truth mask.

        Raises:
            TypeError: If the inputs are not tensors after the (optional)
                transform, e.g. when no transform was supplied.
        """
        name = self.image_names[idx]

        # Images keep whatever mode is stored on disk (no RGB conversion), so
        # the channel count of each tensor is determined by the files.
        img1 = Image.open(os.path.join(self.folder1, name))
        img2 = Image.open(os.path.join(self.folder2, name))
        gt = Image.open(os.path.join(self.gt_folder, name))

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            gt = self.transform(gt)

        # torch.cat only accepts tensors; fail with an actionable message
        # instead of the opaque error it raises for PIL images.
        if not (torch.is_tensor(img1) and torch.is_tensor(img2)):
            raise TypeError(
                "MultiInputSegmentationDataset needs a transform that returns "
                "tensors (got "
                f"{type(img1).__name__} and {type(img2).__name__})"
            )

        # Stack both inputs along the channel dimension.
        input_image = torch.cat((img1, img2), dim=0)

        return input_image, gt


In [None]:
# transform = T.Compose([
#     T.Resize((128, 128)),
#     T.ToTensor(),  # Converts to [0,1] tensor and (C, H, W) format
# ])


In [1]:
import torch
import segmentation_models_pytorch as smp

# `device` is needed below to place the model; prefer the GPU when one is
# visible and fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define model: U-Net with a ResNet-34 encoder.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",     # Encoder starts from ImageNet-pretrained weights
    in_channels=6,                  # 6 input channels (folder1 + folder2 stacked)
    classes=3,                      # 3 output classes
)

model = model.to(device)


ModuleNotFoundError: No module named 'segmentation_models_pytorch'

In [None]:
# (folder1 3-channel image) + (folder2 3-channel image)
#     ---> [Stacked to (6, 128, 128)]
#         ---> Model
#             ---> (3, 128, 128) output
#                 ---> Loss with Ground Truth


In [None]:
import torch
import segmentation_models_pytorch as smp

# Dice loss configured for the multiclass setting.
# NOTE(review): per the library docs this mode expects integer class-index
# targets of shape (N, H, W) -- confirm the targets fed to it match.
loss_fn = smp.losses.DiceLoss(mode="multiclass")

# Adam over all model parameters with a learning rate of 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)


In [None]:
from torch.utils.data import DataLoader

# Dataset and Dataloader
train_dataset = MultiInputSegmentationDataset(folder1, folder2, gt_folder, transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Training loop
num_epochs = 30

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # The dataset returns the mask with a channel axis (N, C, H, W), while
        # the loss is built with mode='multiclass', which works on integer
        # class indices of shape (N, H, W). Collapse the channel axis.
        # Assumes one mask channel per class, in the same order as the model's
        # output channels -- TODO confirm against the ground-truth files.
        if targets.ndim == 4:
            targets = targets.argmax(dim=1)

        optimizer.zero_grad()
        outputs = model(inputs)

        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Mean loss per batch; the max() keeps an empty loader from dividing by zero.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/max(len(train_loader), 1)}")
