In [36]:
import torch
from torch import nn
from torchvision import models
from torchvision.models import resnet50, ResNet50_Weights
import numpy as np
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
from PIL import Image
from torch.utils.data import DataLoader, Dataset

class RetinaDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        # Converti l'immagine in scala di grigi a 3 canali
        # image = image.repeat(3, 1, 1)  # Replicazione del canale per avere 3 canali

        # Usa le dimensioni H e W dal modello
        H, W = 192, 192  # Assicurati che queste siano le dimensioni di output del feature extractor
        coords = np.meshgrid(np.linspace(-1, 1, W), np.linspace(-1, 1, H))
        coords = np.stack(coords, axis=-1).reshape(-1, 2)
        coords = torch.FloatTensor(coords)  # Converte in tensori di torch

        return image, coords, mask.view(-1, 1)




# SineLayer definito per un utilizzo generale nei modelli
class SineLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, is_first=False, omega_0=30):
        super().__init__()
        print("Initializing SineLayer with in_features:", in_features, "and out_features:", out_features)

        self.in_features = in_features  # Aggiunta di questa riga
        self.omega_0 = omega_0
        self.is_first = is_first
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features, 1 / self.in_features)
            else:
                self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, np.sqrt(6 / self.in_features) / self.omega_0)

    def forward(self, x):
        print("Input to SineLayer:", x.shape)
        x = self.linear(x)
        print("Output from Linear:", x.shape)
        return torch.sin(self.omega_0 * x)


# LocalImplicitModel che utilizza SineLayer e altri strati lineari
class LocalImplicitModel(nn.Module):
    def __init__(self, feature_channels, hidden_units=256):
        super().__init__()
        # Il primo layer combina 2 dimensioni delle coordinate con il numero di canali delle features estratte
        self.fc1 = SineLayer(2 + feature_channels, hidden_units, is_first=True)  # 2 (coords) + 256 (features) = 258
        self.fc2 = SineLayer(hidden_units, hidden_units)
        self.output_layer = nn.Linear(hidden_units, 1)
        self.output_activation = nn.Sigmoid()

    def forward(self, features, coords):
        # Assicurati che il reshaping sia corretto
        batch_size, num_pixels, channels = features.size()
        features = features.view(-1, channels)
        coords = coords.view(-1, 2)
        x = torch.cat([coords, features], dim=1)  # Concatenazione delle features con le coordinate
        print("Step 1:", x.shape)
        x = self.fc1(x)
        print("Step 2:", x.shape)
        x = self.fc2(x)
        print("Step 3:", x.shape)
        x = self.output_layer(x)
        print("Step 4:", x.shape)
        x = self.output_activation(x)
        print("Step 5:", x.shape)
        return x



# Feature Extractor basato su ResNet pre-addestrato
def create_feature_extractor():
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    feature_extractor = nn.Sequential(*list(model.children())[:5])  # Usiamo fino al layer 7
    for param in feature_extractor.parameters():
        param.requires_grad = False
    return feature_extractor

# Classe LIIFSegmentation per integrare il tutto
class LIIFSegmentation(nn.Module):
    def __init__(self, feature_extractor, local_implicit_model):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.local_implicit_model = local_implicit_model

    def forward(self, x, coords):
        features = self.feature_extractor(x)  # Estrai le feature dell'immagine
        # Precedentemente era errato perché si assumeva che features avesse quattro dimensioni al momento dello scompattamento
        batch_size, channels, height, width = features.size()  # Corretto
        print('Batch_size:', batch_size)
        print('Channels:', channels)
        print('Height:', height)
        print('Width:', width)
        features = features.view(batch_size, channels, -1).permute(0, 2, 1)  # Da [B, C, H, W] a [B, H*W, C]
        output = self.local_implicit_model(features, coords)  # Calcola l'output per ogni pixel
        output = output.view(batch_size, height, width)
        return output

    

def train(model, dataloader, optimizer, scheduler, epochs=5):
    model.train()
    for epoch in range(epochs):
        for image, coords, mask in dataloader:
            image, coords, mask = image.to(device), coords.to(device), mask.to(device)
            optimizer.zero_grad()
            outputs = model(image, coords)
            loss = nn.MSELoss()(outputs, mask)
            loss.backward()
            optimizer.step()
            scheduler.step(loss)
            print("Current learning rate:", scheduler.optimizer.param_groups[0]['lr'])
            print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# Setup del modello
feature_extractor = create_feature_extractor()
local_implicit_model = LocalImplicitModel(feature_channels=256, hidden_units=256)  # Assumendo 1024 canali dal feature extractor

model = LIIFSegmentation(feature_extractor, local_implicit_model)

# Assumi che x sia il tuo input di immagine e coords siano le coordinate dei pixel
# output = model(x, coords)

train_input = 'C:/Users/Q540900/Desktop/A.I. Master/Werkstudent/RAVIR Dataset/train/training_images'
train_output = 'C:/Users/Q540900/Desktop/A.I. Master/Werkstudent/RAVIR Dataset/train/training_masks'

transform = Compose([
    Resize((192, 192)), #256
    ToTensor(),
    Normalize(mean=[0.5], std=[0.5])
])

train_dataset = RetinaDataset(train_input, train_output, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5)

train(model, train_loader, optimizer, scheduler, epochs=30)

Initializing SineLayer with in_features: 258 and out_features: 256
Initializing SineLayer with in_features: 256 and out_features: 256
Batch_size: 1
Channels: 256
Height: 48
Width: 48


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 36864 but got size 2304 for tensor number 1 in the list.

In [None]:
torch.save(model.state_dict(), 'retina_segmentation_model_LIIF.pth')

In [None]:
test_input = 'C:/Users/Q540900/Desktop/A.I. Master/Werkstudent/RAVIR Dataset/test'

In [None]:
import torch
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
from PIL import Image
import os
import matplotlib.pyplot as plt

# Carica il modello
model = LIIFSegmentation(feature_extractor, local_implicit_model)
model.load_state_dict(torch.load('retina_segmentation_model_LIIF.pth'))
model.eval()  # Set the model to evaluation mode
model.to(device)

# Funzione per processare e visualizzare le immagini
def process_and_visualize(test_path):
    transform = Compose([
        Resize((256, 256)),
        ToTensor(),
        Normalize(mean=[0.5], std=[0.5])
    ])

    images = os.listdir(test_path)
    for img_name in images:
        img_path = os.path.join(test_path, img_name)
        image = Image.open(img_path).convert('L')
        image_tensor = transform(image).unsqueeze(0)  # Add batch dimension
        
        coords = np.meshgrid(np.linspace(-1, 1, image_tensor.shape[3]), np.linspace(-1, 1, image_tensor.shape[2]))
        coords = np.stack(coords, axis=-1)
        coords = torch.FloatTensor(coords).to(device)
        coords = coords.view(1, -1, 2)  # [1, H*W, 2]

        with torch.no_grad():
            output = model(coords)
            output = output.view(1, 256, 256)  # Reshape output to image shape

        # Visualizzazione delle immagini
        fig, ax = plt.subplots(1, 2, figsize=(12, 6))
        ax[0].imshow(image, cmap='gray')
        ax[0].set_title('Input Image')
        ax[0].axis('off')

        output_img = output.cpu().squeeze().numpy()
        ax[1].imshow(output_img, cmap='gray')
        ax[1].set_title('Output Image')
        ax[1].axis('off')

        plt.show()

# Chiamata alla funzione con il percorso di test
process_and_visualize(test_input)
