In [None]:
import os
import csv

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm_

import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR


# Restrict this process to the first GPU (only matters when a GPU is actually used).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# PyTorch has no direct equivalents for some TensorFlow environment settings; GPU memory
# behaviour is managed through PyTorch's own API and log verbosity through Python's
# `logging` module.

# Explicit switch for the compute device. With FORCE_CPU = True the run is pinned to the CPU
# even when CUDA is available (handy while debugging); set it to False to use the GPU whenever
# one is detected.
FORCE_CPU = True
device = torch.device('cuda' if torch.cuda.is_available() and not FORCE_CPU else 'cpu')
print(device)

# Report the PyTorch version in use.
print(torch.__version__)

# Report how many CUDA devices are visible and list them by name.
gpu_count = torch.cuda.device_count()
print("Num GPUs Available: ", gpu_count)
for i in range(gpu_count):
    print(f"Device {i}: {torch.cuda.get_device_name(i)}")

In [None]:
import cv2
import random
import glob
import math

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Spatial size, in pixels, of every synthetic image and label map.
HEIGHT = 96
WIDTH = 256

# Number of samples per synthetic dataset and number of segmentation classes.
NUM_SAMPLES = 100
NUM_CLASSES = 7

# Mini-batch size for the data loaders, and how many mini-batches are accumulated
# before each optimizer step in the training loop.
BATCH_SIZE = 4
ACCUMULATION_STEPS = 8

In [None]:
class SampleDataset(Dataset):
    """Synthetic segmentation dataset made of random images and random per-pixel labels.

    Args:
        num_samples: Number of (image, label-map) pairs to generate.
        num_classes: Number of classes; labels are drawn from ``0 .. num_classes - 1``.
        height: Height of every sample in pixels.
        width: Width of every sample in pixels.
    """

    def __init__(self, num_samples, num_classes, height, width):
        super().__init__()
        self.num_samples = num_samples
        # Images: floats in [0, 1), shaped (num_samples, 3, height, width).
        self.input_data = torch.rand(num_samples, 3, height, width)
        # Labels: torch.randint's upper bound is exclusive, so passing `num_classes`
        # (not `num_classes + 1`) yields exactly the valid class ids 0 .. num_classes - 1,
        # matching the number of output channels of the model.
        self.output_data = torch.randint(0, num_classes, (num_samples, height, width))

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the ``(image, label_map)`` pair stored at position ``idx``."""
        return self.input_data[idx], self.output_data[idx]

In [None]:
# Two independently generated synthetic datasets: one for training, one for validation.
# (Alternative kept from the original cell, not defined in this section:
#  train_gen = DataGenerator(train_folder))
dataset_args = (NUM_SAMPLES, NUM_CLASSES, HEIGHT, WIDTH)
train_gen = SampleDataset(*dataset_args)
val_gen = SampleDataset(*dataset_args)

# Training batches are shuffled; validation batches keep their stored order.
loader_kwargs = {"batch_size": BATCH_SIZE}
train_dataloader = DataLoader(train_gen, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(val_gen, shuffle=False, **loader_kwargs)

# Pull a single training batch as a sanity check of tensor shapes and label values.
imgs, segs = next(iter(train_dataloader))

print(imgs.shape, segs.shape)
print(np.unique(segs))

# Keep the sample batch on the compute device.
imgs, segs = imgs.to(device), segs.to(device)

In [None]:
class SimpleSegmentationModel(nn.Module):
    """Small fully-convolutional network for per-pixel classification.

    Two conv + ReLU + max-pool stages reduce the spatial resolution by a factor of 4, a third
    conv + ReLU widens the features to 256 channels, a 1x1 convolution maps them to
    ``num_classes`` score maps, and bilinear upsampling restores the input resolution.

    Args:
        num_classes: Number of output channels (one raw score map per class).
    """

    def __init__(self, num_classes):
        super().__init__()

        # Feature extractor: 3x3 convolutions with padding=1 keep the spatial size.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)

        # 2x2 max pooling halves height and width.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # 1x1 convolution producing one score map per class.
        self.final_conv = nn.Conv2d(256, num_classes, kernel_size=1)

        # Two 2x bilinear upsampling steps undo the two pooling steps.
        self.upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x):
        """Compute raw (un-normalised) class scores for a batch of images.

        Args:
            x: Tensor of shape (batch, 3, height, width).

        Returns:
            Tensor of shape (batch, num_classes, height, width). No softmax is applied.
        """
        out_size = tuple(x.shape[-2:])

        x = self.pool(F.relu(self.conv1(x)))  # (batch, 64, height//2, width//2)
        x = self.pool(F.relu(self.conv2(x)))  # (batch, 128, height//4, width//4)
        x = F.relu(self.conv3(x))             # (batch, 256, height//4, width//4)

        # Classify at low resolution first. A 1x1 convolution commutes with bilinear
        # upsampling (interpolation weights sum to 1, so the bias is preserved), which gives
        # the same result as upsampling the 256-channel features while only num_classes
        # channels have to be upsampled.
        x = self.final_conv(x)                # (batch, num_classes, height//4, width//4)

        x = self.upsample1(x)                 # (batch, num_classes, ~height//2, ~width//2)
        x = self.upsample2(x)                 # (batch, num_classes, ~height, ~width)

        # When height or width is not a multiple of 4 the floor in the pooling steps makes the
        # upsampled map smaller than the input; resize so the output always matches the labels.
        if tuple(x.shape[-2:]) != out_size:
            x = F.interpolate(x, size=out_size, mode='bilinear', align_corners=False)

        return x
    
    
# Build the segmentation network with one output channel per class (NUM_CLASSES).
model = SimpleSegmentationModel(num_classes=NUM_CLASSES)
# Being the last expression of the cell, the module returned by .to() is also what the
# notebook displays as this cell's output.
model.to(device) # Move the model to `device`, the compute device chosen in the setup cell

In [None]:
def iou_metric(y_true, y_pred, smooth=1, num_classes=NUM_CLASSES):
    """Mean soft intersection-over-union across classes.

    Args:
        y_true: Class-id map; it is flattened, so it must hold batch * height * width elements.
            Ids outside ``0 .. num_classes - 1`` match no class mask.
        y_pred: Per-class scores of shape (batch, num_classes, height, width). The values are
            used as given -- no softmax is applied here -- so pass probabilities if the result
            should be a proper soft IoU.
        smooth: Constant added to numerator and denominator; keeps the ratio defined when a
            class has an empty union.
        num_classes: Number of class channels in ``y_pred``.

    Returns:
        Scalar tensor: the per-class soft IoU averaged over all classes.
    """
    labels = y_true.reshape(-1).float()
    # Channels-last, then one row per pixel: (batch * height * width, num_classes).
    scores = y_pred.permute(0, 2, 3, 1).reshape(-1, num_classes).float()

    def class_iou(class_id):
        # Binary ground-truth mask against the predicted score column of this class.
        target = (labels == class_id).float()
        predicted = scores[:, class_id]

        overlap = torch.sum(target * predicted)
        combined = torch.sum(target) + torch.sum(predicted) - overlap
        return (overlap + smooth) / (combined + smooth)

    return torch.stack([class_iou(class_id) for class_id in range(num_classes)]).mean()

def jaccard_distance_loss(y_true, y_pred, smooth=1, num_classes=NUM_CLASSES):
    """Jaccard distance loss: one minus the mean soft IoU from ``iou_metric``.

    All arguments are forwarded unchanged to ``iou_metric``; the result is a scalar tensor.
    """
    return 1 - iou_metric(y_true, y_pred, smooth=smooth, num_classes=num_classes)

In [None]:
# Epoch budgets; the training loop runs EPOCHS[1] epochs.
EPOCHS = [20, 50]

# Learning rate for training.
learning_rate = 1e-4

In [None]:
# Adam optimizer over all model parameters. The step size is taken from the `learning_rate`
# hyperparameter instead of a duplicated literal, so changing that constant actually
# changes training.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Gradient-norm ceiling applied right before every optimizer step.
max_grad_norm = 1.0

num_epochs = EPOCHS[1]
num_train_batches = len(train_dataloader)
num_val_batches = len(val_dataloader)

for epoch in range(num_epochs):

    # ---------------- Training phase ----------------
    train_iou, train_loss = 0.0, 0.0
    model.train()  # Set the model to training mode
    optimizer.zero_grad()  # never carry stale gradients into a new epoch
    for batch_idx, (imgs, segs) in enumerate(train_dataloader):
        imgs = imgs.to(device)
        segs = segs.to(device)

        # The model returns raw scores, while the soft-IoU loss expects per-class
        # probabilities: normalise over the class dimension first.
        probs = F.softmax(model(imgs), dim=1)
        loss = jaccard_distance_loss(segs, probs)

        batch_loss = loss.item()
        train_loss += batch_loss
        train_iou += 1.0 - batch_loss  # the loss is defined as 1 - IoU

        # Gradients are accumulated over a window of mini-batches. Dividing by the window
        # size makes the accumulated gradient a mean rather than a sum; the last window of
        # an epoch may be shorter than ACCUMULATION_STEPS.
        window_start = (batch_idx // ACCUMULATION_STEPS) * ACCUMULATION_STEPS
        window_size = min(ACCUMULATION_STEPS, num_train_batches - window_start)
        (loss / window_size).backward()

        # Step at the end of every full window and also on the final batch, so the
        # gradients of a trailing partial window are applied instead of being dropped or
        # leaking into the next epoch.
        is_last_batch = (batch_idx + 1) == num_train_batches
        if (batch_idx + 1) % ACCUMULATION_STEPS == 0 or is_last_batch:
            clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()

    # Report per-batch averages so the numbers do not depend on the dataset size.
    train_loss /= max(num_train_batches, 1)
    train_iou /= max(num_train_batches, 1)

    # ---------------- Validation phase ----------------
    val_iou, val_loss = 0.0, 0.0
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # no gradients are needed for evaluation
        for imgs, segs in val_dataloader:
            imgs = imgs.to(device)
            segs = segs.to(device)

            probs = F.softmax(model(imgs), dim=1)
            batch_loss = jaccard_distance_loss(segs, probs).item()
            val_loss += batch_loss
            val_iou += 1.0 - batch_loss

    val_loss /= max(num_val_batches, 1)
    val_iou /= max(num_val_batches, 1)

    print(f"Epoch {epoch + 1}, iou: {train_iou:.4f}, loss: {train_loss:.4f}, val_iou: {val_iou:.4f}, val_loss: {val_loss:.4f}")