In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Class id -> sub-directory name looked up under both the image root and the mask root.
class_labels = dict(enumerate(["Diseased"]))

In [None]:
def _index_masks_by_stem(mask_dir):
    """Map file stem -> mask path for every regular file directly inside ``mask_dir``.

    If several files share a stem, the ``.png`` one wins.
    """
    masks = {}
    for filename in sorted(os.listdir(mask_dir)):
        path = os.path.join(mask_dir, filename)
        if not os.path.isfile(path):
            continue  # skip sub-directories (e.g. per-class mask folders)
        stem, ext = os.path.splitext(filename)
        if stem not in masks or ext.lower() == ".png":
            masks[stem] = path
    return masks


def load_data(image_dir, mask_dir, class_labels, image_extensions=(".jpg", ".png")):
    """Collect aligned (image, mask) file paths for segmentation training.

    Two locations are scanned, in this order:

    1. For every class name in ``class_labels`` (its values, in insertion order):
       ``<image_dir>/<class_name>/`` paired with ``<mask_dir>/<class_name>/``.
    2. Loose image files directly inside ``image_dir``, paired with masks
       directly inside ``mask_dir``.

    An image ``<stem>.<ext>`` is paired with the file in the corresponding mask
    directory that has the same stem (``<stem>.png`` preferred if several
    exist). Images without a matching mask are skipped. As a result, the two
    returned lists always line up index-for-index. Files are visited in sorted
    order, so the result is deterministic.

    Args:
        image_dir: root directory of images (may contain per-class sub-directories).
        mask_dir: root directory of masks, mirroring ``image_dir``'s layout.
        class_labels: mapping of class id -> sub-directory name.
        image_extensions: file extensions (case-insensitive) treated as images.

    Returns:
        ``(image_paths, mask_paths)``: equal-length lists in which
        ``mask_paths[i]`` is the mask for ``image_paths[i]``.

    Raises:
        FileNotFoundError: if ``mask_dir`` or a class sub-directory named in
            ``class_labels`` does not exist.
    """
    allowed = {ext.lower() for ext in image_extensions}

    # Per-class sub-directories first, then loose files at the top level.
    search_dirs = [
        (os.path.join(image_dir, name), os.path.join(mask_dir, name))
        for name in class_labels.values()
    ]
    search_dirs.append((image_dir, mask_dir))

    all_image_paths, all_mask_paths = [], []
    for img_dir, msk_dir in search_dirs:
        masks_by_stem = _index_masks_by_stem(msk_dir)
        for filename in sorted(os.listdir(img_dir)):
            stem, ext = os.path.splitext(filename)
            image_path = os.path.join(img_dir, filename)
            # Ignore sub-directories, hidden files and non-image files.
            if ext.lower() not in allowed or not os.path.isfile(image_path):
                continue
            mask_path = masks_by_stem.get(stem)
            if mask_path is not None:
                all_image_paths.append(image_path)
                all_mask_paths.append(mask_path)

    return all_image_paths, all_mask_paths


In [None]:
class LeafDataset(data.Dataset):
    """Segmentation dataset of RGB images paired with binary masks.

    Image/mask path pairs are discovered once, at construction time, by
    ``load_data(image_dir, mask_dir, class_labels)``. Files are opened lazily in
    ``__getitem__``.

    Args:
        image_dir: root directory of input images.
        mask_dir: root directory of masks.
        class_labels: mapping passed through to ``load_data``.
        image_transform: optional callable applied to the PIL RGB image.
        mask_transform: optional callable applied to the PIL grayscale ("L")
            mask, typically ending in ``ToTensor``. If omitted, the mask is
            converted to a ``(1, H, W)`` tensor here.
    """

    def __init__(self, image_dir, mask_dir, class_labels, image_transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.class_labels = class_labels
        self.image_transform = image_transform
        self.mask_transform = mask_transform

        self.image_paths, self.mask_paths = load_data(image_dir, mask_dir, class_labels)

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _binarize_mask(mask):
        """Return an int64 tensor that is 1 where ``mask > 0`` and 0 elsewhere.

        Non-tensor masks (e.g. a PIL image when no ``mask_transform`` is set)
        are first converted through NumPy. A 2-D result gets a leading channel
        dimension so its shape matches ``ToTensor`` output.
        """
        if not isinstance(mask, torch.Tensor):
            # np.array (not asarray) copies, so torch gets a writable buffer.
            mask = torch.from_numpy(np.array(mask))
            if mask.ndim == 2:
                mask = mask.unsqueeze(0)
        return (mask > 0).long()

    def __getitem__(self, index):
        """Return ``(image, mask)`` for ``index``, with the mask binarised to {0, 1}."""
        image_path = self.image_paths[index]
        mask_path = self.mask_paths[index]

        # Context managers close the file handles; convert() returns a new, loaded image.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        with Image.open(mask_path) as msk:
            mask = msk.convert('L')

        if self.image_transform:
            image = self.image_transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, self._binarize_mask(mask)

In [None]:
# Images and masks are resized to the same fixed size so they can be batched.
IMAGE_SIZE = (128, 128)

image_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

# Masks must use nearest-neighbour resizing. The default (bilinear) blends the
# 0/255 boundary pixels into intermediate grey values. A `> 0` binarisation then
# counts those as foreground, which dilates every mask.
mask_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

In [None]:
# Build the full dataset once; it is split into train/validation subsets below.
leaf_dataset = LeafDataset(
    image_dir='Images/',
    mask_dir='masks/',
    class_labels=class_labels,
    image_transform=image_transform,
    mask_transform=mask_transform,
)

In [None]:
# Train/validation split: TRAIN_FRACTION of the samples for training, the rest
# held out for validation loss and early stopping.
TRAIN_FRACTION = 0.8

total_size = len(leaf_dataset)
print(total_size)
# Fail fast here. An empty (or single-sample) dataset would otherwise only
# surface later as a ZeroDivisionError when averaging the epoch loss.
assert total_size >= 2, (
    f"expected at least 2 image/mask pairs, found {total_size}; check the image/mask directories"
)
train_size = int(TRAIN_FRACTION * total_size)
test_size = total_size - train_size

In [None]:
batch_size = 64  # samples per DataLoader batch

In [None]:
# A seeded generator makes the train/validation split reproducible across kernel restarts.
SPLIT_SEED = 42
train_set, test_set = data.random_split(
    leaf_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# DeepLabV3's ASPP pooling branch applies BatchNorm to a 1x1 feature map. In
# train mode this raises for a batch holding a single sample, so drop the
# trailing batch only when it would contain exactly one sample. With
# shuffle=True, a different sample is skipped each epoch.
train_loader = data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=(len(train_set) % batch_size == 1),
)
test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [None]:
num_classes = 1  # single output channel: per-pixel foreground logit
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)

# Replace the last layer of the DeepLabV3 head so it emits `num_classes` logit map(s).
model.classifier[4] = nn.Sequential(
    nn.Conv2d(256, num_classes, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
)

# With pretrained weights torchvision also builds an auxiliary head. Its output
# ('aux') is never read here, since only ['out'] is used, yet it would be
# computed on every forward pass. The model skips it when it is None.
model.aux_classifier = None

# Show only the (modified) head rather than dumping the whole ResNet-50 backbone.
print(model.classifier)

In [None]:
# BCEWithLogitsLoss applies the sigmoid internally, so the model's raw logits are passed in directly.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Learning rate actually in use. It is read back from the optimizer so there is
# a single source of truth; a separately hard-coded value can silently disagree.
learning_rate = optimizer.param_groups[0]["lr"]

# Early-stopping state and limits (reset each time this cell is run).
early_stop_counter = 0          # consecutive epochs without validation improvement
early_stop_patience = 5         # stop after this many non-improving epochs
best_valid_loss = float('inf')
num_epochs = 100


In [None]:
def visualize_sample(image, mask_true, mask_generated):
    """Plot an input image next to its ground-truth and predicted masks.

    Args:
        image: image tensor, assumed (C, H, W) with displayable values
            (e.g. ``ToTensor`` output in [0, 1]) -- TODO confirm for other transforms.
        mask_true: ground-truth mask tensor; singleton dimensions are squeezed.
        mask_generated: predicted mask tensor (e.g. sigmoid probabilities);
            singleton dimensions are squeezed.
    """
    def to_numpy(tensor):
        # detach() drops autograd history so .numpy() is allowed on training
        # outputs; cpu() is a no-op for CPU tensors, so any device works.
        return tensor.detach().cpu().numpy()

    image_np = np.transpose(to_numpy(image), (1, 2, 0))  # CHW -> HWC for imshow
    true_np = to_numpy(mask_true).squeeze()
    pred_np = to_numpy(mask_generated).squeeze()

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = [
        (image_np, 'Original Image', None),
        (true_np, 'True Mask', 'gray'),
        (pred_np, 'Generated Mask', 'gray'),
    ]
    for ax, (img, title, cmap) in zip(axes, panels):
        ax.imshow(img, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure; this is called repeatedly during training

In [None]:
# Train with early stopping on validation loss, checkpointing the best weights.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Module.to() moves parameters in place, so the optimizer built earlier keeps valid references.
model.to(device)
checkpoint_path = 'best_model.pt'
best_saved = False  # guards the final reload against a stale checkpoint from an earlier run


def run_epoch(loader, train):
    """Run one pass over ``loader``, updating the weights when ``train`` is True.

    Returns:
        ``(mean_loss, sample)``. ``mean_loss`` is the per-sample loss averaged
        over the batches actually seen. ``sample`` is a detached
        ``(image, true_mask, predicted_probabilities)`` triple from the last
        batch, for visualisation.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train(train)  # train(False) == eval()
    total_loss, n_seen, sample = 0.0, 0, None
    with torch.set_grad_enabled(train):
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device).float()  # BCE targets must be float

            outputs = model(images)['out']
            loss = criterion(outputs, masks)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item() * images.size(0)
            n_seen += images.size(0)
            # Detach so holding this sample does not keep the batch's autograd graph alive.
            sample = (images[0], masks[0], outputs[0].detach().sigmoid())

    if n_seen == 0:
        raise ValueError("loader produced no batches; check the dataset split and batch size")
    return total_loss / n_seen, sample


for epoch in range(num_epochs):
    train_loss, train_sample = run_epoch(train_loader, train=True)
    # One visualisation per phase per epoch (not per batch) to keep output readable.
    print("training")
    visualize_sample(*train_sample)

    valid_loss, valid_sample = run_epoch(test_loader, train=False)
    print("validating")
    visualize_sample(*valid_sample)

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}')

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        early_stop_counter = 0
        torch.save(model.state_dict(), checkpoint_path)
        best_saved = True
    else:
        early_stop_counter += 1
        if early_stop_counter >= early_stop_patience:
            print("Early stopping triggered!")
            break

if best_saved:
    # map_location keeps the reload working regardless of the device the checkpoint was written from.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    print("No checkpoint was saved in this run (validation loss never improved); keeping final weights.")
