In [None]:
from transformers.modeling_outputs import SemanticSegmenterOutput  
from transformers import Dinov2Model, Dinov2PreTrainedModel  
from retouch_dataloader_utils import load_train_and_val  
from torch.utils.tensorboard import SummaryWriter
from aroi_dataloder_utils import load_val
from torch.utils.data import DataLoader  
from torch.utils.data import Dataset  
import torch.nn.functional as F  
from torch.optim import AdamW  
from tqdm.auto import tqdm  
import albumentations as A 
from PIL import Image  
import pandas as pd
import numpy as np  
import torch  
import cv2 
import csv 
import os 


# Ask PyTorch whether CUDA is available; as the last expression of the cell the
# boolean result is shown as the cell output.
torch.cuda.is_available() 

In [None]:
class SegmentationDataset(Dataset):
    """Dataset that loads an image and a label map from ``.npy`` files.

    Args:
        dataset: indexable collection whose items provide ``"image_path"`` and
            ``"label_path"`` entries pointing at files readable by ``np.load``.
        transform: callable invoked as ``transform(image=..., mask=...)`` that
            returns a mapping with ``"image"`` and ``"mask"`` entries.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load, transform and return one sample.

        Returns:
            A 5-tuple ``(image, target, original_image,
            original_segmentation_map, image_path)`` where ``image`` is the
            transformed image with its last axis moved to the front, ``target``
            is the transformed mask as a ``LongTensor`` and the remaining items
            are the untransformed arrays and the source path.
        """
        record = self.dataset[idx]
        image_path = record["image_path"]

        # The loaded array is repeated three times along a new trailing axis so
        # the image carries three identical channels.
        raw_image = np.load(image_path)
        original_image = np.stack((raw_image, raw_image, raw_image), axis=-1)
        original_segmentation_map = np.load(record["label_path"])

        augmented = self.transform(image=original_image, mask=original_segmentation_map)

        # Channels-last -> channels-first.
        image = torch.tensor(augmented["image"]).permute(2, 0, 1)
        target = torch.LongTensor(augmented["mask"])

        return image, target, original_image, original_segmentation_map, image_path

In [None]:
def collate_fn(inputs):
    """Assemble a list of dataset samples into a single batch dictionary.

    Args:
        inputs: sequence of samples, each indexable at positions 0-4 as
            (image tensor, label tensor, original image,
            original segmentation map, image path).

    Returns:
        dict with stacked tensors under ``"pixel_values"`` and ``"labels"``
        and plain Python lists under ``"original_images"``,
        ``"original_segmentation_maps"`` and ``"image_path"``.
    """
    def field(position):
        # Collect the same positional entry from every sample.
        return [sample[position] for sample in inputs]

    return {
        "pixel_values": torch.stack(field(0), dim=0),
        "labels": torch.stack(field(1), dim=0),
        "original_images": field(2),
        "original_segmentation_maps": field(3),
        "image_path": field(4),
    }

In [None]:
class LinearClassifier(torch.nn.Module):
    """Small convolutional head turning a token grid into per-label logit maps.

    Args:
        in_channels: size of each token embedding.
        tokenW: number of tokens along the grid width.
        tokenH: number of tokens along the grid height.
        num_labels: number of output channels produced by the final 1x1 conv.
    """

    def __init__(self, in_channels, tokenW=32, tokenH=32, num_labels=1):
        super().__init__()

        self.in_channels = in_channels
        self.width = tokenW
        self.height = tokenH

        # Single hidden layer (the variant currently in use).
        self.conv = torch.nn.Conv2d(in_channels, 128, kernel_size=(6, 6), padding=1)
        self.classifier = torch.nn.Conv2d(128, num_labels, kernel_size=(1, 1))

        # Two hidden layers (alternative variant, kept for reference):
        # self.conv1 = torch.nn.Conv2d(in_channels, 64, (6,6), padding=1)
        # self.conv2 = torch.nn.Conv2d(64, 128, (6,6), padding=1)
        # self.classifier = torch.nn.Conv2d(128, num_labels, (1,1))

    def forward(self, embeddings):
        """Map token embeddings to logits.

        ``embeddings`` is reshaped to ``(-1, tokenH, tokenW, in_channels)`` and
        moved to channels-first layout before the convolutions. Because the
        6x6 convolution uses padding 1, each spatial side of the output is
        3 smaller than the token grid.
        """
        token_grid = embeddings.reshape(-1, self.height, self.width, self.in_channels)
        features = token_grid.permute(0, 3, 1, 2)

        # Single hidden layer.
        hidden = torch.relu(self.conv(features))
        return self.classifier(hidden)

        # Two hidden layers (alternative variant, kept for reference):
        # x = torch.relu(self.conv1(embeddings))
        # x = torch.relu(self.conv2(x))
        # return self.classifier(x)

In [None]:
class Dinov2ForSemanticSegmentation(Dinov2PreTrainedModel):
    """DINOv2 backbone followed by a ``LinearClassifier`` segmentation head.

    Args:
        config: model configuration; ``hidden_size`` and ``num_labels`` are
            read from it to size the head.
    """

    def __init__(self, config):
        super().__init__(config)

        self.dinov2 = Dinov2Model(config)
        # The head is built for a fixed 32 x 32 token grid.
        # NOTE(review): presumably matches 448 x 448 inputs with 14-pixel
        # patches - confirm against the backbone configuration.
        self.classifier = LinearClassifier(config.hidden_size, 32, 32, config.num_labels)

    def forward(self, pixel_values, output_hidden_states=False, output_attentions=False, labels=None):
        """Run the backbone and the head and optionally compute the loss.

        Args:
            pixel_values: input batch; its trailing two dimensions define the
                size the logits are upsampled to.
            output_hidden_states: forwarded to the backbone.
            output_attentions: forwarded to the backbone.
            labels: optional integer label maps. When given, a cross-entropy
                loss that ignores label ``0`` is returned as well.

        Returns:
            ``SemanticSegmenterOutput`` with ``loss`` (``None`` without
            labels), ``logits`` and the backbone's hidden states / attentions.
        """
        outputs = self.dinov2(
            pixel_values,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        # Skip the first token of the sequence (presumably the [CLS] token) and
        # keep the remaining tokens for the head.
        patch_embeddings = outputs.last_hidden_state[:, 1:, :]

        logits = self.classifier(patch_embeddings)
        logits = torch.nn.functional.interpolate(
            logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False
        )

        loss = None
        if labels is not None:
            # A bare ``.squeeze()`` also removes the batch axis when the batch
            # holds a single sample, which makes CrossEntropyLoss see
            # mismatching shapes. Only an explicit size-1 channel axis on the
            # labels is dropped; the logits keep their (N, C, H, W) layout.
            if labels.dim() == logits.dim() and labels.shape[1] == 1:
                labels = labels.squeeze(1)
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=0)
            loss = loss_fct(logits, labels)

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [None]:
def get_scan_number(image_dir, path):
    """Return the scan number encoded in ``path``.

    Every occurrence of ``image_dir`` is removed from ``path``; the remainder
    is split on ``'/'`` and the component at index 1 is converted to ``int``.
    """
    relative_parts = path.replace(image_dir, '').split('/')
    return int(relative_parts[1])

In [None]:
def get_scan_index(image_dir, path):
    """Return the slice index encoded in the file name of ``path``.

    After removing ``image_dir`` from ``path`` the component at index 2 of the
    ``'/'``-split remainder is taken as the file name. Its ``.npy`` suffix is
    removed and the second ``'_'``-separated field is converted to ``int``.
    """
    file_name = path.replace(image_dir, '').split('/')[2]
    name_fields = file_name.replace('.npy', '').split('_')
    return int(name_fields[1])

In [None]:
def compute_dice_for_inference_3d(preds_3d, targets_3d, num_classes):
    """Compute one Dice score per class over a whole volume.

    Args:
        preds_3d: tensor of predicted class indices with three dimensions.
        targets_3d: tensor of ground-truth class indices, same shape.
        num_classes: number of classes used for the one-hot encoding.

    Returns:
        Tensor with ``num_classes`` entries holding
        ``2 * |P & T| / (|P| + |T| + 1e-6)`` for every class. A class absent
        from both volumes therefore scores 0.
    """
    def class_masks(volume):
        # Add a leading batch axis, one-hot encode, and move the class axis
        # directly behind the batch axis.
        batched = volume.unsqueeze(0).long()
        encoded = torch.nn.functional.one_hot(batched, num_classes=num_classes)
        return encoded.permute(0, 4, 1, 2, 3).float()

    pred_masks = class_masks(preds_3d)
    target_masks = class_masks(targets_3d)

    volume_axes = (2, 3, 4)
    overlap = (pred_masks * target_masks).sum(dim=volume_axes)
    # Sum of both mask sizes (not a set union).
    cardinality = pred_masks.sum(dim=volume_axes) + target_masks.sum(dim=volume_axes)

    per_class_dice = (2.0 * overlap) / (cardinality + 1e-6)
    return per_class_dice.squeeze(0)

In [None]:
# Load a trained checkpoint and evaluate it scan by scan on the validation fold
def run_3d_inference_on_models(epoch, learning_rate, cross_validation, path, image_dir, label_dir):
    """Compute per-scan 3D Dice scores for one checkpoint and save them as CSV.

    The checkpoint ``<path><epoch>_<learning_rate>_<cross_validation>.pth`` is
    loaded into a ``Dinov2ForSemanticSegmentation`` model. Every validation
    slice is predicted, the slices of each scan are stacked in slice-index
    order, and ``compute_dice_for_inference_3d`` is evaluated on the stacked
    volume. One row per scan is written to
    ``<epoch>_<learning_rate>_<cross_validation>_3d_inference.csv`` in the
    current working directory.

    Args:
        epoch: epoch count, used only to build the checkpoint / CSV names.
        learning_rate: learning rate, used only to build the names.
        cross_validation: value passed to ``load_train_and_val`` as its fourth
            argument and used in the names (presumably the fold index).
        path: prefix of the checkpoint file; the file name is appended
            directly, so it should end with a path separator.
        image_dir: image root directory; also stripped from the image paths to
            recover scan number and slice index.
        label_dir: label root directory.

    Raises:
        KeyError: if a scan does not contain every slice index from 0 up to
            its largest index.
    """
    checkpoint_path = path + str(epoch) + '_' + str(learning_rate) + '_' + str(cross_validation) + '.pth'
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = Dinov2ForSemanticSegmentation.from_pretrained("facebook/dinov2-base", id2label=id2label, num_labels=len(id2label))
    # Load onto the CPU first so that checkpoints saved from a GPU can also be
    # restored on a machine without CUDA; the model is moved to `device` below.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))

    for name, param in model.named_parameters():
        if name.startswith("dinov2"):
            param.requires_grad = False

    model = model.to(device)
    model.eval()  # make inference mode explicit

    dataset = load_train_and_val(image_dir, label_dir, 5, cross_validation)
    val_dataset = SegmentationDataset(dataset["validation"], transform=val_transform)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True, collate_fn=collate_fn)

    # Cache every slice by (scan number, slice index) and remember the largest
    # slice index seen for each scan so volumes can be rebuilt in order.
    image_dict = {}
    label_dict = {}
    maxdict = {}
    for batch in tqdm(val_dataloader):
        image_path = batch["image_path"][0]
        scan_number = get_scan_number(image_dir, image_path)
        scan_index = get_scan_index(image_dir, image_path)

        image_dict[(scan_number, scan_index)] = batch["pixel_values"]
        label_dict[(scan_number, scan_index)] = batch["labels"]
        maxdict[scan_number] = max(maxdict.get(scan_number, scan_index), scan_index)

    # One Dice column per class, named after id2label in class-index order.
    columns = ["scan_number"] + [id2label[class_id] for class_id in range(len(id2label))]
    rows = []

    for scan_number, last_index in maxdict.items():
        preds_3d = []
        labels_3d = []
        for slice_index in range(last_index + 1):
            test_image = image_dict[(scan_number, slice_index)]
            labels = label_dict[(scan_number, slice_index)]
            with torch.no_grad():
                logits = model(test_image.to(device)).logits
            preds_3d.append(logits.argmax(dim=1).squeeze(0).cpu())
            labels_3d.append(labels.squeeze(0).cpu())

        dice_score = compute_dice_for_inference_3d(
            torch.stack(preds_3d, dim=0),
            torch.stack(labels_3d, dim=0),
            num_classes=len(id2label),
        )
        dice_list = [round(val, 4) for val in dice_score.squeeze().tolist()]
        rows.append([scan_number, *dice_list])

    # Build the frame once instead of appending row by row.
    df = pd.DataFrame(rows, columns=columns)

    csv_filename = str(epoch) + '_' + str(learning_rate) + '_' + str(cross_validation) + '_3d_inference.csv'
    print(csv_filename,' Is ready!')
    df.to_csv(csv_filename, index=False)

In [None]:
# Class index -> class name, in class-index order.
id2label = dict(enumerate(("Background", "IRF", "SRF", "PED")))

In [None]:
# Per-channel mean / standard deviation handed to A.Normalize.
# NOTE(review): these values are on a 0-255 scale, while albumentations'
# Normalize additionally scales mean/std by `max_pixel_value` (255 by default).
# Confirm this matches the transform the checkpoints were trained with before
# changing anything here.
ADE_MEAN = [123.675, 116.280, 103.530]
ADE_STD = [58.395, 57.120, 57.375]

# Validation-time preprocessing: resize to 448 x 448, then normalize.
val_transform = A.Compose(
    [
        A.Resize(height=448, width=448),
        A.Normalize(mean=ADE_MEAN, std=ADE_STD),
    ],
    is_check_shapes=False,
)

In [None]:
# Checkpoint prefix (the checkpoint file name is appended directly, hence the
# trailing slash) and the image / label roots of the data set.
path = '.../modellek/'
image_dir = '.../RETOUCH_TRAINING/imagesTr'
label_dir = '.../RETOUCH_TRAINING/labelsTr'

# Grid of checkpoints to evaluate. Learning rates stay strings because they are
# inserted verbatim into the checkpoint and CSV file names.
epochs = [5, 10, 100]
learning_rates = ['1e-2', '1e-3', '1e-4', '1e-5']
cross_validation = 5

for epoch in epochs:
    for lr in learning_rates:
        for cross_val in range(cross_validation):
            print(cross_val)
            # The function requires the checkpoint prefix and both data
            # directories in addition to the grid values.
            run_3d_inference_on_models(epoch, lr, cross_val, path, image_dir, label_dir)