In [120]:
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim  as optim
from PIL import Image
from MyDataset import MyDataset
from glob import glob
import os
from tqdm import tqdm
from matplotlib import pyplot as plt

# Target device for the model and every batch: the GPU when PyTorch can see one, else the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'


%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [81]:
# Root directory of the REFUGE dataset. Set the REFUGE_DIR environment variable to
# point at your copy instead of editing this machine-specific default path.
images_path = os.environ.get('REFUGE_DIR', r'C:\my files\REFUGE')

In [86]:
# Glob patterns, relative to images_path, for the image (.jpg) and mask (.bmp) files
# of each split. They contain '**', so they must be expanded with glob(..., recursive=True).
train_images_path, train_masks_path = (
    'Training400/**/*.jpg',
    'Annotation-Training400/Disc_Cup_Masks/**/*.bmp',
)
val_images_path, val_masks_path = (
    'REFUGE-Validation400/**/*.jpg',
    'REFUGE-Validation400-GT/**/*.bmp',
)

In [87]:
def _sorted_glob(pattern):
    """Return files under ``images_path`` matching the recursive ``pattern``, sorted by path.

    Sorting makes the order deterministic, which is what lets images and masks be
    paired up by list index further down.
    """
    return sorted(glob(os.path.join(images_path, pattern), recursive=True))


def _stem(path):
    """File name without directory or extension, e.g. '.../g0001.jpg' -> 'g0001'."""
    return os.path.splitext(os.path.basename(path))[0]


def _check_pairs(images, masks, split):
    """Fail loudly if ``images`` and ``masks`` cannot be paired one-to-one by index.

    Without this, a wrong root path yields empty lists and a missing/extra file
    silently shifts every later image onto the wrong mask.
    NOTE(review): assumes each mask shares its image's file stem — confirm for the dataset.
    """
    assert images, f'no {split} images found under {images_path!r}'
    assert len(images) == len(masks), (
        f'{split}: found {len(images)} images but {len(masks)} masks')
    mismatched = [(img, msk) for img, msk in zip(images, masks) if _stem(img) != _stem(msk)]
    assert not mismatched, f'{split}: image/mask names differ, e.g. {mismatched[0]}'


train_images = _sorted_glob(train_images_path)
train_masks = _sorted_glob(train_masks_path)
val_images = _sorted_glob(val_images_path)
val_masks = _sorted_glob(val_masks_path)

_check_pairs(train_images, train_masks, 'train')
_check_pairs(val_images, val_masks, 'val')

In [117]:
BATCH_SIZE = 16

# Images: Resize(256) scales the shorter side to 256 px (aspect ratio kept, bilinear),
# then ToTensor converts to a float tensor scaled to [0, 1].
# NOTE(review): batching assumes every image in a split has the same aspect ratio — confirm.
image_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
])

# Masks: same geometry, but nearest-neighbour resizing so label values are copied,
# not blended into intermediate levels that appear in no annotation.
mask_transforms = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

# Old shared name, kept so any later cell that refers to it still works.
data_transforms = image_transforms

# NOTE(review): assumes MyDataset applies its 3rd argument to images and its 4th to masks.
train_dataset = MyDataset(train_images, train_masks, image_transforms, mask_transforms)
val_dataset = MyDataset(val_images, val_masks, image_transforms, mask_transforms)

# Reshuffle training samples every epoch; validation order does not affect the loss.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [112]:
class SegmentationModel(nn.Module):
    """Small convolutional encoder-decoder that predicts a per-pixel score map.

    The encoder has three Conv-BN-ReLU-MaxPool stages (16 -> 32 -> 64 channels), so it
    downsamples by 8. The decoder upsamples by 4 with two transposed convolutions.
    ``forward`` then bilinearly interpolates back to the input's exact spatial size.

    Args:
        in_channels: number of channels in the input images (e.g. 3 for RGB).
        out_channels: number of output maps. The final 1x1 convolution has no
            activation, so outputs are raw logits.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, out_channels, kernel_size=1))

    def forward(self, x):
        """Return logits with the same spatial size as ``x``.

        Args:
            x: input batch of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, H, W) when ``out_channels == 1``, otherwise
            (N, out_channels, H, W). The batch axis is always kept, even for N == 1.
        """
        spatial_size = x.size()[2:]
        x = self.encoder(x)
        x = self.decoder(x)

        # Resize exactly to the input size. This also covers the remaining 2x gap
        # between 8x downsampling and 4x upsampling, and floor-rounding of odd
        # H/W inside the max-pool layers.
        x = F.interpolate(x, spatial_size, mode='bilinear', align_corners=True)

        # Drop only the channel axis (a no-op when out_channels > 1). A bare squeeze()
        # would also remove the batch axis for a batch of one, breaking the loss shape.
        return x.squeeze(1)


In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None, progress=False):
    """Make one pass over ``loader`` and return the mean per-batch loss.

    Args:
        model: network to run; switched to train mode when ``optimizer`` is given,
            otherwise to eval mode.
        loader: iterable of ``(images, masks)`` batches.
        criterion: loss function called as ``criterion(outputs, masks)``.
        device: device every batch is moved to before the forward pass.
        optimizer: when given, each batch is back-propagated and an optimizer step
            is taken. When ``None``, the pass runs with gradients disabled.
        progress: wrap ``loader`` in a tqdm progress bar.

    Returns:
        Mean of ``loss.item()`` over the batches, or ``float('nan')`` when the
        loader yields no batches (instead of dividing by zero).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    batches = tqdm(loader) if progress else loader
    total_loss, n_batches = 0.0, 0
    with torch.set_grad_enabled(training):
        for images, masks in batches:
            images = images.to(device)
            masks = masks.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            if training:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            n_batches += 1
    return total_loss / n_batches if n_batches else float('nan')


# Model, loss (expects raw logits) and optimizer. `optim` comes from the imports cell.
model = SegmentationModel(in_channels=3, out_channels=1).to(device)
criterion = nn.BCEWithLogitsLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10

for epoch in range(num_epochs):
    train_loss = _run_epoch(model, train_loader, criterion, device, optimizer, progress=True)
    val_loss = _run_epoch(model, val_loader, criterion, device)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")


100%|██████████| 25/25 [01:35<00:00,  3.83s/it]


Epoch [1/10], Train Loss: 0.8042, Val Loss: 0.7377


100%|██████████| 25/25 [01:40<00:00,  4.01s/it]
