## Exploring the Dataset

# Mount Google Drive at /content/drive (this import only exists inside Google Colab).
# NOTE(review): the dataset paths configured in a later cell point at C:/ locations,
# not /content/drive — confirm which environment this notebook is meant to run in.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
pip install tqdm

: 

In [None]:
import torch
# Print the installed PyTorch version.
print(torch.__version__)

: 

In [None]:
# Dataset switch read by the path-configuration cell:
#   1              -> full Train/images + Train/labels folders
#   anything else  -> the *_test debug folders (debug mode, e.g. i = 0)
i = 1

: 

In [None]:
# Folder that test_model() writes predicted masks into.
output_dir = 'output'
# Raw string: without it the Windows backslashes are parsed as (invalid) escape
# sequences such as '\e', which raise a SyntaxWarning on Python 3.12+.
# The resulting path value is unchanged.
test_dir = r'C:\example\Test_studentversion\images'

if i == 1:
    # Full training set.
    train_img_folder = 'C:/example/Train/Train/images'
    train_gt_folder = 'C:/example/Train/Train/labels'
else:
    # Debug subset: input images from images_test, ground-truth masks from labels_test.
    train_img_folder = 'C:/example/Train/Train/images_test'
    train_gt_folder = 'C:/example/Train/Train/labels_test'

: 

In [None]:
import os
from time import sleep
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms  
from PIL import Image
from tqdm import tqdm

# Warn early, naming the offending path, if a configured training folder is
# missing. os.listdir() in the dataset needs an actual directory, so check that.
for folder in (train_img_folder, train_gt_folder):
    if not os.path.isdir(folder):
        print(f'Folder not exists: {folder}')

print("Start Training.....")

# Custom dataset for image segmentation
class SegmentationDataset(Dataset):
    """Paired (image, mask) dataset read from two folders.

    Every entry of ``image_dir`` is treated as an input image; its mask is looked
    up in ``mask_dir`` via ``_mask_path``. Images are opened as RGB and masks as
    single-channel ('L') PIL images. ``transform``, when given, is a dict whose
    'image' and 'mask' callables are applied to the image and mask respectively.

    Each item is ``(image, mask, img_name)``.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # Sorted so the index -> file mapping is reproducible (os.listdir order is arbitrary).
        self.images = sorted(os.listdir(image_dir))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def _mask_path(self, img_name):
        """Return the path of the mask that belongs to image ``img_name``.

        The legacy mapping (``'mask.png' -> 'sat.jpg'``, a no-op for names that
        do not contain 'mask.png') is tried first, so existing folder layouts keep
        working. If that file does not exist, the reverse mapping
        (``'sat.jpg' -> 'mask.png'``) is tried. If neither exists, the legacy path
        is returned so opening it fails with the same path as before.
        """
        primary = os.path.join(self.mask_dir, img_name.replace('mask.png', 'sat.jpg'))
        if os.path.exists(primary):
            return primary
        fallback = os.path.join(self.mask_dir, img_name.replace('sat.jpg', 'mask.png'))
        if os.path.exists(fallback):
            return fallback
        return primary

    def __getitem__(self, index):
        img_name = self.images[index]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = self._mask_path(img_name)

        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')

        if self.transform:
            image = self.transform['image'](image)
            mask = self.transform['mask'](mask)

        return image, mask, img_name
    
class TestDataset(Dataset):
    """Unlabelled images from ``image_dir``, prepared for inference.

    Each item is ``(tensor, file_name)``: the image is opened as RGB, resized to
    ``size`` and converted with ``transforms.ToTensor``.
    """

    def __init__(self, image_dir, size=(512, 512)):
        self.image_dir = image_dir
        # Sorted so processing order is reproducible (os.listdir order is arbitrary).
        self.images = sorted(os.listdir(image_dir))
        # Built once here rather than on every __getitem__ call.
        # NOTE(review): the training transform resizes to 256x256 while this
        # defaults to 512x512 — confirm the scale mismatch is intended.
        self.preprocess = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        name = self.images[index]
        image = Image.open(os.path.join(self.image_dir, name)).convert('RGB')
        return self.preprocess(image), name

# Plot the per-epoch training curves (accuracy and loss) on one figure.
def show_image_mask(num1, num2, label):
    """Draw two per-epoch series on a freshly cleared figure and show it.

    Despite its name, this does not display images or masks. It plots ``num1``
    as the 'Accuracy' line and ``num2`` as the 'Loss' line against the epoch
    index.

    Args:
        num1: sequence of accuracy values, one per epoch.
        num2: sequence of loss values, one per epoch.
        label: figure title.
    """
    plt.clf()  # clear whatever the previous call drew
    plt.title(label)
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.plot(num1, label='Accuracy')
    plt.plot(num2, label='Loss')
    plt.legend(loc='right')
    plt.show()

# Define transformations for images and masks. Both are resized to the same
# 256x256 size so the default DataLoader collate can stack them into batches.
transform = {
    'image': transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]),
    'mask': transforms.Compose([
        # Nearest-neighbour keeps mask label values intact; the default bilinear
        # interpolation would blend values along region boundaries.
        transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
        transforms.ToTensor(),
    ]),
}

# Initialize dataset and dataloader
train_data = SegmentationDataset(train_img_folder, train_gt_folder, transform=transform)
train_loader = DataLoader(train_data, batch_size=3, shuffle=True)

# Simple CNN model for segmentation
class SimpleSegmentationModel(nn.Module):
    """Small fully-convolutional network producing a 1-channel probability map.

    Spatial size for an H x W RGB input:
        conv1 (H) -> x2 upsample (2H) -> conv2 -> x2 upsample (4H)
        -> conv3 -> conv4 (4H) -> conv5, stride 2 (2H) -> x2 upsample (4H)
        -> conv6, stride 4 (H)
    so the output has the input's height and width. The final sigmoid puts the
    values in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Creation order is significant: it fixes both the state_dict layout and
        # the order in which random initial weights are drawn.
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)      # 3 input channels (RGB)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv5 = nn.Conv2d(128, 64, 2, stride=2)     # halves H and W
        self.conv6 = nn.Conv2d(64, 1, 4, stride=4)       # quarters H and W; 1-channel mask
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        act, up = self.relu, self.upsample
        feat = up(act(self.conv1(x)))
        feat = up(act(self.conv2(feat)))
        feat = act(self.conv4(act(self.conv3(feat))))
        feat = up(act(self.conv5(feat)))
        return torch.sigmoid(self.conv6(feat))  # per-pixel probability

def value_accuracy(outputs, masks):
    """Return pixel accuracy after thresholding both tensors at 0.5.

    A pixel counts as correct when ``outputs > 0.5`` and ``masks > 0.5`` agree.
    The result is a Python float: matching pixels / ``outputs.numel()``.
    """
    agree = (outputs > 0.5) == (masks > 0.5)
    return agree.sum().item() / outputs.numel()

# Function to save the model
def save_checkpoint(model, epoch, checkpoint_dir='checkpoints'):
    """Save ``model.state_dict()`` as ``<checkpoint_dir>/checkpoint_<epoch>.pth``.

    The directory is created if it does not exist. Returns the written path.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    target = os.path.join(checkpoint_dir, 'checkpoint_{}.pth'.format(epoch))
    torch.save(model.state_dict(), target)
    return target

def test_model(model_path):
    """Run the checkpoint at ``model_path`` over every image in ``test_dir``.

    Each prediction (the model's sigmoid output, in [0, 1]) is scaled to 0-255,
    converted to an 8-bit grayscale image, and saved into ``output_dir`` under
    the input image's file name.
    NOTE(review): if test images are .jpg, the predicted masks are saved as
    lossy JPEG — confirm that is acceptable.
    """
    test_loader = DataLoader(TestDataset(test_dir), batch_size=1, shuffle=False)
    # Fall back to CPU so inference also works on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = SimpleSegmentationModel()
    # map_location lets a checkpoint written from GPU tensors load on CPU.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    os.makedirs(output_dir, exist_ok=True)
    print('Testing model...')
    with torch.no_grad():
        for images, image_names in test_loader:
            outputs = model(images.to(device))
            for output, image_name in zip(outputs, image_names):
                # (1, H, W) probabilities -> (H, W) uint8 in 0..255.
                pixels = (output.squeeze().cpu().numpy() * 255).astype('uint8')
                Image.fromarray(pixels).save(os.path.join(output_dir, image_name))

# Setup device, model, loss function, and optimizer.
# Fall back to CPU when CUDA is unavailable instead of failing at startup.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
model = SimpleSegmentationModel().to(device)
criterion = nn.BCELoss()  # the model already ends in a sigmoid, so BCE on probabilities
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 200
checkpoint_every = 1   # save a checkpoint (and run test_model on it) every N epochs
arr_acc = []           # mean per-batch accuracy, one entry per epoch
arr_loss = []          # mean per-batch loss, one entry per epoch

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    total_num = len(train_loader)  # number of batches per epoch
    progress_bar = tqdm(train_loader, unit='batch', desc=f'Epoch [{epoch+1}/{num_epochs}]')
    for images, masks, _ in progress_bar:  # iterating the bar advances it
        images, masks = images.to(device), masks.to(device)
        # Clear gradients from the previous step; without this they accumulate
        # across every batch and epoch and the updates grow without bound.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_acc += value_accuracy(outputs, masks)
    progress_bar.close()

    # Saving and testing share one condition, so test_model never sees an
    # undefined path when checkpoint_every > 1.
    if (epoch + 1) % checkpoint_every == 0:
        model_path = save_checkpoint(model, epoch + 1)
        test_model(model_path)

    # Guard against an empty loader (would otherwise divide by zero).
    epoch_loss = total_loss / max(total_num, 1)
    epoch_acc = total_acc / max(total_num, 1)
    print(f'Loss: {epoch_loss:.4f} Value_Accuracy: {epoch_acc:.4f}')
    arr_acc.append(epoch_acc)
    arr_loss.append(epoch_loss)
    if len(arr_acc) > 1:
        show_image_mask(arr_acc, arr_loss, 'Multiple Line Plots')

print('Task completed!')


: 