In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np

In [6]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from PIL import Image

# Custom Dataset
class CTScanDataset(Dataset):
    def __init__(self, image_dir, label_dir):
        """
        image_dir: Path to the folder containing input images.
        label_dir: Path to the folder containing label images.
        """
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.image_names = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        # Load the image and label
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_dir, image_name)
        label_path = os.path.join(self.label_dir, image_name)

        # Open the image and label, convert to NumPy arrays
        image = np.array(Image.open(image_path).convert("L"), dtype=np.float32) / 255.0
        label = np.array(Image.open(label_path).convert("L"), dtype=np.int64)

        # Flatten the image and label
        image = image.flatten()
        label = label.flatten()

        return torch.tensor(image), torch.tensor(label)

# Fully Connected Model
class FullyConnectedClassifier(nn.Module):
    def __init__(self, input_size, num_classes):
        super(FullyConnectedClassifier, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        return self.fc(x)

# Hyperparameters
INPUT_SIZE = 640 * 640  # Flattened size of the image
NUM_CLASSES = 16  # Adjust based on the number of label classes
BATCH_SIZE = 8
EPOCHS = 10
LEARNING_RATE = 0.001

# Replace with your actual folder paths
image_dir = "../amos22/Train/input"
label_dir = "../amos22/Train/label"

dataset = CTScanDataset(image_dir, label_dir)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Model, Loss, Optimizer
model = FullyConnectedClassifier(INPUT_SIZE, NUM_CLASSES)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training Loop
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for images, labels in dataloader:
        # Forward pass
        outputs = model(images)  # Shape: (batch_size * 640 * 640, num_classes)
        labels = labels.view(-1)  # Flatten labels to match output: (batch_size * 640 * 640)

        # Calculate loss
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch [{epoch + 1}/{EPOCHS}], Loss: {total_loss / len(dataloader):.4f}")

print("Training complete!")

ValueError: Expected input batch_size (8) to match target batch_size (3276800).