In [3]:
from transformers.modeling_outputs import SemanticSegmenterOutput  
from transformers import Dinov2Model, Dinov2PreTrainedModel  
from retouch_dataloader_utils import load_train_and_val  
from torch.utils.tensorboard import SummaryWriter
from aroi_dataloder_utils import load_val
from torch.utils.data import DataLoader  
from torch.utils.data import Dataset  
import torch.nn.functional as F  
from torch.optim import AdamW  
from tqdm.auto import tqdm  
import albumentations as A 
from PIL import Image  
import pandas as pd
import numpy as np  
import torch  
import cv2 
import csv 
import os 


# Warn early when no CUDA device is visible: training still runs on CPU, but far slower.
# (The message says: "Use a GPU to train the model! It is much faster that way...")
CUDA_AVAILABLE = torch.cuda.is_available()
if not CUDA_AVAILABLE:
    print("Használjon GPU-t a modell tanításához! úgy sokkal gyorsabb...")

In [None]:
class SegmentationDataset(Dataset):
    """Dataset of (image, mask) pairs stored as NumPy ``.npy`` files.

    Each element of ``dataset`` must support ``record["image_path"]`` and
    ``record["label_path"]``. The loaded image is replicated into three
    channels along a new last axis, then ``transform(image=..., mask=...)``
    is applied (albumentations-style: it must return a mapping with
    ``'image'`` and ``'mask'`` entries).
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, target, original_image, original_segmentation_map, image_path)``.

        ``image`` is the transformed image as a tensor with its last axis moved
        to the front (HWC -> CHW). ``target`` is the transformed mask as an
        int64 tensor. The untransformed NumPy arrays and the source path are
        returned alongside them.
        """
        record = self.dataset[idx]
        img_path = record["image_path"]

        # Replicate the (presumably 2-D grayscale — TODO confirm) scan into 3 channels
        # so it matches the 3-channel input the backbone expects.
        raw_image = np.stack([np.load(img_path)] * 3, axis=-1)
        raw_mask = np.load(record["label_path"])

        augmented = self.transform(image=raw_image, mask=raw_mask)
        image_tensor = torch.tensor(augmented['image']).permute(2, 0, 1)
        mask_tensor = torch.LongTensor(augmented['mask'])

        return image_tensor, mask_tensor, raw_image, raw_mask, img_path

In [None]:
def collate_fn(inputs):
    """Collate a list of ``SegmentationDataset`` samples into one batch dict.

    Each sample is indexed positionally: 0 = image tensor, 1 = label tensor,
    2 = original image, 3 = original segmentation map, 4 = image path.
    The first two are stacked along a new leading batch dimension; the rest
    are left unstacked as Python lists.
    """
    def column(position):
        # Gather the ``position``-th field of every sample, preserving order.
        return [sample[position] for sample in inputs]

    return {
        "pixel_values": torch.stack(column(0), dim=0),
        "labels": torch.stack(column(1), dim=0),
        "original_images": column(2),
        "original_segmentation_maps": column(3),
        "image_path": column(4),
    }

In [None]:
class LinearClassifier(torch.nn.Module):
    """Convolutional segmentation head applied to a grid of patch embeddings.

    Despite the name, the head is not purely linear. A 6x6 convolution
    (128 output channels, padding 1) with ReLU is followed by a 1x1
    convolution producing ``num_labels`` logits per location. With kernel 6
    and padding 1, each spatial side shrinks by 3 (a 32x32 token grid yields
    29x29 logits), so callers are expected to upsample the result.

    Args:
        in_channels: embedding dimension of each token.
        tokenW: width of the token grid.
        tokenH: height of the token grid.
        num_labels: number of output classes.
    """

    def __init__(self, in_channels, tokenW=32, tokenH=32, num_labels=1):
        super().__init__()

        self.in_channels = in_channels
        self.width = tokenW
        self.height = tokenH

        # Single hidden conv layer (an earlier experiment stacked 64- and
        # 128-channel 6x6 convs before the 1x1 classifier).
        self.conv = torch.nn.Conv2d(in_channels, 128, (6, 6), padding=1)
        self.classifier = torch.nn.Conv2d(128, num_labels, (1, 1))

    def forward(self, embeddings):
        """Map token embeddings to per-location class logits.

        ``embeddings`` is reshaped to (-1, tokenH, tokenW, in_channels), so it
        presumably has shape (batch, tokenH * tokenW, in_channels) — TODO confirm
        against callers. Returns a (batch, num_labels, tokenH - 3, tokenW - 3) tensor.
        """
        grid = embeddings.reshape(-1, self.height, self.width, self.in_channels)
        grid = grid.permute(0, 3, 1, 2)  # NHWC -> NCHW, as Conv2d expects
        hidden = torch.relu(self.conv(grid))
        return self.classifier(hidden)

In [None]:
class Dinov2ForSemanticSegmentation(Dinov2PreTrainedModel):
    """DINOv2 backbone followed by a convolutional segmentation head.

    The first token of the backbone output (the CLS token in Hugging Face's
    ``Dinov2Model``) is dropped. The remaining patch tokens are decoded by
    ``LinearClassifier`` on a hardcoded 32x32 token grid, so inputs must yield
    exactly 32x32 patches (presumably 448x448 images with 14-px patches — TODO
    confirm). The logits are then bilinearly upsampled to the input resolution.
    """

    def __init__(self, config):
        super().__init__(config)

        self.dinov2 = Dinov2Model(config)
        self.classifier = LinearClassifier(config.hidden_size, 32, 32, config.num_labels)

    def forward(self, pixel_values, output_hidden_states=False, output_attentions=False, labels=None):
        """Run the backbone and head; optionally compute the cross-entropy loss.

        Args:
            pixel_values: image batch; ``shape[2:]`` is used as the output size,
                so presumably (batch, channels, H, W).
            output_hidden_states: forwarded to the backbone.
            output_attentions: forwarded to the backbone.
            labels: optional integer mask batch of shape (batch, H, W); a
                (batch, 1, H, W) mask is also accepted.

        Returns:
            ``SemanticSegmenterOutput`` with ``logits`` of shape
            (batch, num_labels, H, W). When ``labels`` is given, ``loss`` is the
            cross-entropy loss with label 0 ignored.
        """
        outputs = self.dinov2(pixel_values,
                              output_hidden_states=output_hidden_states,
                              output_attentions=output_attentions)
        # Keep only patch tokens (drop the leading CLS token).
        patch_embeddings = outputs.last_hidden_state[:, 1:, :]

        logits = self.classifier(patch_embeddings)
        logits = torch.nn.functional.interpolate(logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False)

        loss = None
        if labels is not None:
            # Never call a bare .squeeze() here: with batch size 1 it removes the
            # batch dim, and CrossEntropyLoss would misread (C, H, W) logits as
            # (N, C, d1) and fail. Only drop an explicit singleton channel dim.
            if labels.dim() == logits.dim() and labels.shape[1] == 1:
                labels = labels.squeeze(1)
            # Label 0 is excluded from the loss.
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=0)
            loss = loss_fct(logits, labels)

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [None]:
def dice_loss(preds, targets, num_classes, smooth=1.0):
    """Soft multi-class Dice loss.

    Args:
        preds: raw logits; softmax is taken over dim 1, so presumably
            (batch, num_classes, H, W).
        targets: integer class indices, (batch, H, W). The one-hot tensor is
            permuted as a 4-D tensor, so it must be 3-D.
        num_classes: number of classes used for one-hot encoding.
        smooth: additive smoothing applied to numerator and denominator.

    Returns:
        0-d tensor ``1 - mean(dice)``, where the mean runs over batch and all
        classes (background included).
    """
    probs = F.softmax(preds, dim=1).float()
    one_hot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()

    spatial_dims = (2, 3)
    overlap = torch.sum(probs * one_hot, dim=spatial_dims)
    total = torch.sum(probs, dim=spatial_dims) + torch.sum(one_hot, dim=spatial_dims)

    per_class_dice = (overlap * 2.0 + smooth) / (total + smooth)
    return 1 - per_class_dice.mean()

In [None]:
def compute_dice_coefficient(preds, targets, num_classes, eps=1e-6):
    """Hard Dice coefficient between predicted and target label maps.

    Args:
        preds: integer predicted labels, (batch, H, W).
        targets: integer ground-truth labels, (batch, H, W).
        num_classes: number of classes used for one-hot encoding.
        eps: smoothing term. A class absent from both prediction and target
            scores eps / eps = 1.

    Returns:
        Python float: Dice averaged over batch and all classes.
    """
    def to_one_hot(labels):
        # (B, H, W) -> (B, C, H, W) float mask.
        return torch.nn.functional.one_hot(labels, num_classes).permute(0, 3, 1, 2).float()

    pred_mask, true_mask = to_one_hot(preds), to_one_hot(targets)

    axes = (2, 3)
    inter = (pred_mask * true_mask).sum(dim=axes)
    denom = pred_mask.sum(dim=axes) + true_mask.sum(dim=axes)

    return ((inter * 2.0 + eps) / (denom + eps)).mean().item()

In [None]:
def create_tensorboard_writer(logging_dir_name):
    """Return a TensorBoard ``SummaryWriter`` logging to ``runs/<logging_dir_name>``."""
    # Each run gets its own sub-folder under runs/.
    log_dir = "runs/" + logging_dir_name
    return SummaryWriter(log_dir=log_dir)

In [None]:
def training(epochs, learning_rate, cross_val, train_dataloader):
    """Train a fresh DINOv2 segmentation model and return it.

    The ``facebook/dinov2-base`` backbone is frozen, so only the ``classifier``
    head is updated. Each step minimises cross-entropy (label 0 ignored) plus
    soft Dice loss. Per-epoch mean loss and mean hard Dice are printed and
    written to TensorBoard under ``runs/<epochs>_<learning_rate>_<cross_val>``.

    Args:
        epochs: number of passes over ``train_dataloader``.
        learning_rate: AdamW learning rate. May be a string such as ``'1e-3'``;
            it is used verbatim in the log-dir name and then converted to float.
        cross_val: cross-validation fold index, used only in the log-dir name.
        train_dataloader: yields dicts with ``"pixel_values"`` and ``"labels"``
            tensors, as produced by ``collate_fn``.

    Returns:
        The trained ``Dinov2ForSemanticSegmentation`` model, left in train mode.

    Note:
        Reads the module-level ``id2label`` mapping, which must be defined
        before this function is called.
    """
    logging_dir_name = str(epochs) + '_' + str(learning_rate) + '_' + str(cross_val)
    writer = create_tensorboard_writer(logging_dir_name)
    try:
        num_classes = len(id2label)
        model = Dinov2ForSemanticSegmentation.from_pretrained("facebook/dinov2-base", id2label=id2label, num_labels=num_classes)

        # Freeze the backbone: only the segmentation head is trained.
        for name, param in model.named_parameters():
            if name.startswith("dinov2"):
                param.requires_grad = False

        learning_rate = float(learning_rate)

        optimizer = AdamW(model.parameters(), lr=learning_rate)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.train()

        ce_loss_fn = torch.nn.CrossEntropyLoss(ignore_index=0)

        for epoch in range(epochs):
            print(f"Epoch {epoch+1}/{epochs}")

            dice_score_avg = []
            loss_avg = []
            print(f"Learning rate: {optimizer.param_groups[0]['lr']:.6f}")

            for batch in tqdm(train_dataloader):
                pixel_values = batch["pixel_values"].to(device)
                labels = batch["labels"].to(device)
                # Accept (B, 1, H, W) masks, but never squeeze the batch dim: a
                # bare .squeeze() breaks CrossEntropyLoss when a batch has size 1
                # (e.g. the last, partial batch of an epoch).
                if labels.dim() == 4 and labels.shape[1] == 1:
                    labels = labels.squeeze(1)

                # Labels are deliberately not passed to the model: its internal
                # loss would be discarded, since CE + Dice is computed here.
                logits = model(pixel_values).logits

                loss_ce = ce_loss_fn(logits, labels)
                loss_dice = dice_loss(logits, labels, num_classes=num_classes)
                loss = loss_ce + loss_dice
                loss_avg.append(loss.item())

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                with torch.no_grad():
                    preds = logits.argmax(dim=1)
                    dice_score = compute_dice_coefficient(preds.cpu(), labels.cpu(), num_classes=num_classes)
                    dice_score_avg.append(dice_score)

            mean_dice = np.nanmean(dice_score_avg)
            mean_loss = np.nanmean(loss_avg)
            print("Dice score average", mean_dice)
            print("Loss average", mean_loss)

            writer.add_scalar("Loss/train", mean_loss, epoch)
            # NOTE: the tag says "Accuracy" but the logged value is the mean hard Dice score.
            writer.add_scalar("Accuracy/train", mean_dice, epoch)
    finally:
        # Flush and release the event file even if training fails midway.
        writer.close()
    return model

In [None]:
def run_train_with_cross_validation(epochs, learning_rates, cross_validation, image_images, label_dir, model_dir):
    """Grid-search epoch counts x learning rates with k-fold cross-validation.

    For every (epoch count, learning rate, fold) combination, a fresh model is
    trained on that fold's training split. Its ``state_dict`` is saved to
    ``<model_dir>/<epochs>_<lr>_<fold>.pth``.

    Args:
        epochs: iterable of epoch counts to try.
        learning_rates: iterable of learning rates (strings or numbers).
        cross_validation: number of folds.
        image_images: directory with the training images, passed to
            ``load_train_and_val``. The parameter name is kept for
            compatibility with keyword callers.
        label_dir: directory with the label masks.
        model_dir: output directory for checkpoints; created if missing.
    """
    os.makedirs(model_dir, exist_ok=True)
    for epoch in epochs:
        for lr in learning_rates:
            for cross_val in range(cross_validation):
                # Use the image directory argument rather than a notebook-global variable.
                dataset = load_train_and_val(image_images, label_dir, cross_validation, cross_val)

                train_dataset = SegmentationDataset(dataset["train"], transform=train_transform)
                train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0, pin_memory=True, collate_fn=collate_fn)
                model = training(epoch, lr, cross_val, train_dataloader)

                # os.path.join avoids a double separator when model_dir ends with '/'.
                checkpoint_path = os.path.join(model_dir, f"{epoch}_{lr}_{cross_val}.pth")
                torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Class index -> label name for the segmentation head; index 0 is the background class.
id2label = dict(enumerate(["Background", "IRF", "SRF", "PED"]))

In [None]:
# ImageNet channel statistics, given on the 0-255 pixel scale and rescaled to 0-1.
# albumentations' Normalize computes (img - mean * max_pixel_value) / (std * max_pixel_value)
# with max_pixel_value=255 by default, so mean/std must be passed on the 0-1 scale.
# Unscaled values would shift pixels by ~31,500 and compress their contrast 255x.
# Assumes the .npy scans are on a 0-255 intensity scale — TODO confirm.
# NOTE: checkpoints trained with the old (unscaled) statistics expect the old input scaling.
ADE_MEAN = (np.array([123.675, 116.280, 103.530]) / 255).tolist()
ADE_STD = (np.array([58.395, 57.120, 57.375]) / 255).tolist()

train_transform = A.Compose([
    # 448 px = 32 tokens per side with 14-px patches, matching the head's hardcoded 32x32 grid.
    A.Resize(width=448, height=448),
    A.Normalize(mean=ADE_MEAN, std=ADE_STD),
], is_check_shapes=False)

In [None]:
# Placeholder paths: point these at the local RETOUCH training data and the checkpoint folder.
image_dir = '.../RETOUCH_TRAINING/imagesTr'
label_dir = '.../RETOUCH_TRAINING/labelsTr'
model_dir = '.../modellek/'

# Seed torch so data shuffling and the randomly initialised head are reproducible.
SEED = 42
torch.manual_seed(SEED)

epochs = [5, 10, 100]
learning_rates = ['1e-2', '1e-3', '1e-4', '1e-5']  # strings: used verbatim in run/checkpoint names
cross_validation = 5
# The function requires all six arguments, including the data and checkpoint directories.
run_train_with_cross_validation(epochs, learning_rates, cross_validation, image_dir, label_dir, model_dir)