In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import nibabel as nib
import numpy as np
from tqdm import tqdm
import os
from sklearn.model_selection import train_test_split 
# Export CUDA_LAUNCH_BLOCKING=1 into this process's environment.
# NOTE(review): this is usually a CUDA debugging switch and presumably slows
# GPU training — confirm it is still wanted for regular runs.
os.environ.update({"CUDA_LAUNCH_BLOCKING": "1"})



class DoubleConv(nn.Module):
    """Two consecutive Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. All six layers live in ``self.double_conv`` in that order.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Stage 1 consumes the incoming channels, stage 2 the channels produced
        # by stage 1; layers are created in the same order they are applied.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/norm/activation stages."""
        return self.double_conv(x)

class Down(nn.Module):
    """Encoder step: ``MaxPool2d(2)`` followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pooling = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pooling, convs)

    def forward(self, x):
        """Pool ``x`` and then apply the double convolution."""
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Decoder step: upsample, merge with the skip connection, then DoubleConv.

    ``x1`` (the deeper feature map) is upsampled by a stride-2 transposed
    convolution that halves its channel count, concatenated with the skip
    feature map ``x2`` along the channel axis, and passed through a
    ``DoubleConv(in_channels, out_channels)``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it to ``x2`` and fuse both.

        Args:
            x1: feature map from the deeper level, indexed as (N, C, H, W).
            x2: skip-connection feature map from the encoder, (N, C, H, W).

        Returns:
            The output of the double convolution applied to ``cat([x2, x1])``.
        """
        x1 = self.up(x1)

        # When an encoder level had an odd height/width, pooling floored it and
        # the upsampled map comes back one pixel short of the skip map. Pad it
        # so the concatenation below does not fail; this is a no-op whenever
        # both maps already share the same spatial size.
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        if diff_h != 0 or diff_w != 0:
            x1 = nn.functional.pad(
                x1,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    """Final 1x1 convolution mapping ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution; no activation follows it."""
        projected = self.conv(x)
        return projected

class UNet(nn.Module):
    """U-Net built from DoubleConv / Down / Up / OutConv blocks.

    The encoder widens the channels 64 -> 128 -> 256 -> 512 -> 1024 over four
    ``Down`` steps; the decoder mirrors it with four ``Up`` steps that each
    receive the matching encoder output as a skip connection. ``outc`` maps
    the final 64 channels to ``n_classes`` channels with no activation.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        # Encoder path.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        # Decoder path.
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        # Per-pixel class scores.
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return the ``outc`` output for the input batch ``x``."""
        # Encoder: keep every level's output for the skip connections.
        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        enc4 = self.down3(enc3)
        bottleneck = self.down4(enc4)

        # Decoder: each step fuses with the encoder output of the same level.
        dec = self.up1(bottleneck, enc4)
        dec = self.up2(dec, enc3)
        dec = self.up3(dec, enc2)
        dec = self.up4(dec, enc1)
        return self.outc(dec)



class BratsDataset(Dataset):
    """Slice-wise dataset over FLAIR volumes and their segmentation masks.

    Every patient directory that contains exactly one ``*_flair.nii`` file and
    exactly one segmentation file (``*_seg.nii`` or ``*Segm.nii``) contributes
    ``num_slices_per_scan`` items, one per index along the volume's third axis.
    Directories that do not satisfy this are skipped without an error.

    Args:
        patient_dirs: iterable of ``pathlib.Path`` patient directories.
        num_slices_per_scan: number of slices taken from every volume
            (defaults to 155, the value previously hard-coded here).
    """

    def __init__(self, patient_dirs, num_slices_per_scan=155):
        super().__init__()
        self.patient_dirs = patient_dirs
        self.image_paths = []
        self.mask_paths = []
        self.num_slices_per_scan = num_slices_per_scan

        print("Veri seti taraniyor ve dosyalar dogrulaniyor...")

        for patient_dir in tqdm(self.patient_dirs, desc="Hastalar Taranıyor"):
            flair_files = list(patient_dir.glob('*_flair.nii'))
            seg_files = list(patient_dir.glob('*_seg.nii')) + list(patient_dir.glob('*Segm.nii'))

            # Keep only patients with an unambiguous image/mask pair.
            if len(flair_files) == 1 and len(seg_files) == 1:
                self.image_paths.append(flair_files[0])
                self.mask_paths.append(seg_files[0])

        print(f"Dogrulama tamamlandi. Toplam {len(self.image_paths)} adet gecerli hasta bulundu.")

    def __len__(self):
        """Total number of 2D slices across all accepted patients."""
        return len(self.image_paths) * self.num_slices_per_scan

    def __getitem__(self, idx):
        """Return one slice as ``{'image': float32 (1, H, W), 'mask': int64 (H, W)}``.

        Label value 4 in the mask is remapped to 3 so the labels are contiguous.
        """
        patient_idx, slice_idx = divmod(idx, self.num_slices_per_scan)

        image_path = self.image_paths[patient_idx]
        mask_path = self.mask_paths[patient_idx]

        # Index the image's ``dataobj`` so that only the requested slice is
        # converted, instead of materialising both complete volumes as float64
        # via ``get_fdata()`` for every single item. ``np.array`` always copies,
        # so the tensors below own their memory.
        image_slice = np.array(nib.load(image_path).dataobj[:, :, slice_idx], dtype=np.float32)
        mask_slice = np.array(nib.load(mask_path).dataobj[:, :, slice_idx], dtype=np.int64)

        image_tensor = torch.from_numpy(image_slice).unsqueeze(0)
        mask_tensor = torch.from_numpy(mask_slice)

        mask_tensor[mask_tensor == 4] = 3

        return {'image': image_tensor, 'mask': mask_tensor}



try:
    # Folder that holds one sub-directory per patient.
    data_folder = Path('./data/MICCAI_BraTS2020_TrainingData/')

    # iterdir() yields entries in arbitrary order; sort them so the seeded
    # split below selects the same patients on every machine and run.
    all_patient_dirs = sorted(d for d in data_folder.iterdir() if d.is_dir())

    # Split by patient directory, 80% training / 20% validation.
    train_dirs, val_dirs = train_test_split(all_patient_dirs, test_size=0.2, random_state=42)

    print(f"Toplam {len(all_patient_dirs)} hasta bulundu.")
    print(f"{len(train_dirs)} hasta egitim icin, {len(val_dirs)} hasta validasyon için ayrildi.")

    print("\n--- Eğitim Veri Seti Yükleniyor ---")
    train_dataset = BratsDataset(patient_dirs=train_dirs)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)

    print("\n--- Validasyon Veri Seti Yukleniyor ---")
    validation_dataset = BratsDataset(patient_dirs=val_dirs)
    validation_loader = DataLoader(validation_dataset, batch_size=16, shuffle=False, num_workers=0)

    # Fail early with a clear message; otherwise an empty loader only shows up
    # later as a bare "division by zero" when the epoch loss is averaged.
    if len(train_loader) == 0 or len(validation_loader) == 0:
        raise RuntimeError(
            f"No usable scans found under {data_folder}: "
            "the training or validation set is empty."
        )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'\nUsing device: {device}')

    # One input channel, four output classes (labels 0-3 after the dataset's remap).
    model = UNet(n_channels=1, n_classes=4).to(device)

    print("\nYeni bir model oluşturuldu. Egitime sifirdan baslaniyor...")

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    num_epochs = 5

    for epoch in range(num_epochs):

        # ---- Training pass ----
        model.train()
        epoch_train_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Eğitim]"):
            images = batch['image'].to(device=device)
            true_masks = batch['mask'].to(device=device)

            predicted_masks = model(images)
            loss = criterion(predicted_masks, true_masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_train_loss += loss.item()

        avg_train_loss = epoch_train_loss / len(train_loader)

        # ---- Validation pass (no gradient tracking, no weight updates) ----
        model.eval()
        epoch_val_loss = 0
        with torch.no_grad():
            for batch in tqdm(validation_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Doğrulama]"):
                images = batch['image'].to(device=device)
                true_masks = batch['mask'].to(device=device)

                predicted_masks = model(images)
                loss = criterion(predicted_masks, true_masks)

                epoch_val_loss += loss.item()

        avg_val_loss = epoch_val_loss / len(validation_loader)

        print(f"Epoch {epoch + 1}/{num_epochs} -> "
              f"Ortalama Egitim Kaybi (Train Loss): {avg_train_loss:.4f}, "
              f"Ortalama Dogrulama Kaybi (Validation Loss): {avg_val_loss:.4f}")

        # Keep a separate weights file per epoch.
        torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pth")
        print(f"Model 'model_epoch_{epoch+1}.pth' olarak kaydedildi.\n")

    print("Training and validation finished!")

except Exception as e:
    print("\nEgitim sirasinda bir hata olustu:")
    print(e)
    # Re-raise so the notebook shows the full traceback and stops here instead
    # of continuing to later cells with a missing or half-trained `model`.
    raise

Toplam 369 hasta bulundu.
295 hasta egitim icin, 74 hasta validasyon için ayrildi.

--- Eğitim Veri Seti Yükleniyor ---
Veri seti taraniyor ve dosyalar dogrulaniyor...


Hastalar Taranıyor: 100%|██████████| 295/295 [00:00<00:00, 3778.01it/s]


Dogrulama tamamlandi. Toplam 295 adet gecerli hasta bulundu.

--- Validasyon Veri Seti Yukleniyor ---
Veri seti taraniyor ve dosyalar dogrulaniyor...


Hastalar Taranıyor: 100%|██████████| 74/74 [00:00<00:00, 3449.61it/s]

Dogrulama tamamlandi. Toplam 74 adet gecerli hasta bulundu.

Using device: cuda






Yeni bir model oluşturuldu. Egitime sifirdan baslaniyor...


Epoch 1/5 [Eğitim]:   5%|▌         | 143/2858 [02:15<37:29,  1.21it/s] 

In [None]:

# Write the model's current weights to a stand-alone checkpoint file.
MODEL_PATH = "model_checkpoint.pth"
torch.save(obj=model.state_dict(), f=MODEL_PATH)

print(f"Modelin mevcut durumu {MODEL_PATH} dosyasina kaydedildi!")

Modelin mevcut durumu model_checkpoint.pth dosyasina kaydedildi!
