In [7]:
import torch
from torch import nn
import os
from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau

class RetinaDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name)
        image = Image.open(img_path).convert('L')
        mask = Image.open(mask_path).convert('L')

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        # Genera le coordinate normalizzate per ogni pixel
        coords = np.meshgrid(np.linspace(-1, 1, image.shape[2]), np.linspace(-1, 1, image.shape[1]))
        coords = np.stack(coords, axis=-1)
        coords = torch.FloatTensor(coords)  # Converte in tensori di torch
        coords = coords.view(-1, 2)  # Flatten delle coordinate in [N, 2] dove N è il numero di pixel

        return coords, image.view(-1, 1), mask.view(-1, 1)

class SineLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, is_first=False, omega_0=30):
        super().__init__()
        self.in_features = in_features 
        self.omega_0 = omega_0
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights(is_first)

    def init_weights(self, is_first):
        with torch.no_grad():
            if is_first:
                self.linear.weight.uniform_(-1 / self.in_features, 1 / self.in_features)
            else:
                self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, 
                                            np.sqrt(6 / self.in_features) / self.omega_0)

    def forward(self, x):
        return torch.sin(self.omega_0 * self.linear(x))

class Siren(nn.Module):
    def __init__(self, in_features, hidden_features, hidden_layers, out_features, first_omega_0=30, hidden_omega_0=30):
        super().__init__()
        self.layers = nn.ModuleList([SineLayer(in_features, hidden_features, is_first=True, omega_0=first_omega_0)])
        for _ in range(hidden_layers):
            self.layers.append(SineLayer(hidden_features, hidden_features, omega_0=hidden_omega_0))
        self.final_layer = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.final_layer(x)

def train(model, dataloader, optimizer, scheduler, epochs=5):
    model.train()
    for epoch in range(epochs):
        for coords, image, mask in dataloader:
            coords, image, mask = coords.to(device), image.to(device), mask.to(device)
            optimizer.zero_grad()
            outputs = model(coords)
            loss = nn.MSELoss()(outputs, mask)  # Calcola la perdita usando la maschera come target
            loss.backward()
            optimizer.step()
            scheduler.step(loss)  # Passa la perdita al scheduler
            print("Current learning rate:", scheduler.optimizer.param_groups[0]['lr'])
            print(f"Epoch {epoch}, Loss: {loss.item()}")

# Settings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_input = 'C:/Users/Q540900/Desktop/A.I. Master/Werkstudent/RAVIR Dataset/train/training_images'
train_output = 'C:/Users/Q540900/Desktop/A.I. Master/Werkstudent/RAVIR Dataset/train/training_masks'

transform = Compose([
    # Resize((768, 768)), #256
    ToTensor(),
    Normalize(mean=[0.4227], std=[0.1457])
])

train_dataset = RetinaDataset(train_input, train_output, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
2843
# Model
model = Siren(in_features=2, hidden_features=768, hidden_layers=3, out_features=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.8)

# Training
train(model, train_loader, optimizer, scheduler, epochs=30)


Current learning rate: 0.001
Epoch 0, Loss: 10.431130409240723
Current learning rate: 0.001
Epoch 0, Loss: 15.18659782409668
Current learning rate: 0.001
Epoch 0, Loss: 13.843182563781738
Current learning rate: 0.001
Epoch 0, Loss: 10.250557899475098
Current learning rate: 0.001
Epoch 0, Loss: 9.32960319519043
Current learning rate: 0.001
Epoch 0, Loss: 8.972789764404297
Current learning rate: 0.001
Epoch 0, Loss: 8.630720138549805
Current learning rate: 0.001
Epoch 0, Loss: 8.612722396850586
Current learning rate: 0.001
Epoch 0, Loss: 8.787675857543945
Current learning rate: 0.001
Epoch 0, Loss: 8.331501007080078
Current learning rate: 0.001
Epoch 0, Loss: 8.492801666259766
Current learning rate: 0.001
Epoch 0, Loss: 8.275425910949707
Current learning rate: 0.001
Epoch 0, Loss: 8.446663856506348
Current learning rate: 0.001
Epoch 0, Loss: 8.226096153259277
Current learning rate: 0.001
Epoch 0, Loss: 8.39275074005127
Current learning rate: 0.001
Epoch 0, Loss: 8.24657917022705
Current 

In [8]:
torch.save(model.state_dict(), 'INR_model.pth')