<a href="https://colab.research.google.com/github/BlueBerry-Coder/Practice/blob/main/Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from volumentations import CenterCrop, Compose, RandomCrop, Rotate, Transpose

from loss import BCEDiceLoss
from model import UNET3D

In [None]:
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_images,
)

# --- Run configuration ------------------------------------------------------
LEARNING_RATE = 1e-3  # Adam learning rate used in main()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer the GPU when one is present
BATCH_SIZE = 1  # batch size handed to get_loaders()
NUM_EPOCHS = 10  # number of training epochs run by main()
# NOTE(review): NUM_WORKERS and PIN_MEMORY are not passed to get_loaders() in
# main(); confirm whether the loader factory is expected to receive them.
NUM_WORKERS = 6
# Spatial size of the cropped input volumes; keep in sync with the crop transforms.
IMAGE_DEPTH = 128
IMAGE_HEIGHT = 128
IMAGE_WIDTH = 128
FEATURES = [32, 64, 128, 256]  # per-level feature widths passed to UNET3D(features=...)
PIN_MEMORY = True
LOAD_MODEL = False  # when True, main() restores weights from my_checkpoint.pth.tar first
TRAIN_IMG_DIR = r"D:/PyCode/Project/data/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"
# The previous value had stray characters ("ну ") appended, which pointed at a
# non-existent directory.
# NOTE(review): confirm this split ships ground-truth masks; check_accuracy()
# is run against it every epoch.
VAL_IMG_DIR = r"D:/PyCode/Project/data/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData"

In [None]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one training epoch over ``loader`` using AMP-style loss scaling.

    Args:
        loader: Iterable yielding ``(data, targets)`` tensor pairs.
        model: Network to optimise; expected to already live on ``DEVICE``.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        scaler: ``torch.cuda.amp.GradScaler`` used for the backward/step.

    Returns:
        The mean of the per-batch loss values for this epoch, or ``0.0`` when
        ``loader`` yields no batches.
    """
    # Re-enable training behaviour (dropout, batch-norm statistics): a
    # validation pass run between epochs may have left the model in eval mode.
    model.train()
    # Autocast only does anything on CUDA; on CPU torch.cuda.amp.autocast
    # merely warns and disables itself.
    use_amp = DEVICE == "cuda"

    loop = tqdm(loader)
    total_loss = 0.0
    num_batches = 0

    for data, targets in loop:
        data = data.to(device=DEVICE)
        targets = targets.float().to(device=DEVICE)

        # forward
        with torch.cuda.amp.autocast(enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward: the scaler scales the loss before backward() and
        # unscales the gradients inside step().
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1
        loop.set_postfix(loss=loss_value)

    return total_loss / num_batches if num_batches else 0.0

def main():
    """Train UNET3D for ``NUM_EPOCHS`` epochs, checkpointing and validating each epoch.

    Builds the train/validation transforms, the model, loss, optimizer and
    data loaders, optionally restores weights from ``my_checkpoint.pth.tar``,
    and then, after every epoch, saves a checkpoint via ``save_checkpoint``
    and runs ``check_accuracy`` on the validation loader.
    """
    # NOTE(review): assumes volumes are laid out (H, W, D, C) before the
    # Transpose below moves channels first; confirm the axis order if the
    # IMAGE_* sizes are ever made unequal.
    crop_shape = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH)

    train_transform = Compose([
        Rotate((-30, 30), (-30, 30), (-30, 30), p=1.0),
        RandomCrop(crop_shape, p=1.0),
        Transpose(axes=(3, 0, 1, 2), always_apply=True),
    ])

    # Validation uses a deterministic center crop so that scores are
    # comparable from one epoch to the next (a random crop made them noisy).
    val_transform = Compose([
        CenterCrop(crop_shape, p=1.0),
        Transpose(axes=(3, 0, 1, 2), always_apply=True),
    ])

    model = UNET3D(in_channels=4, out_channels=3, features=FEATURES).to(DEVICE)
    loss_fn = BCEDiceLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # NOTE(review): NUM_WORKERS / PIN_MEMORY are defined at module level but are
    # not passed here; confirm whether get_loaders() accepts them.
    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        VAL_IMG_DIR,
        BATCH_SIZE,
        train_transform,
        val_transform,
    )

    if LOAD_MODEL:
        # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
        load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model)

    # Loss scaling only matters for CUDA mixed precision; keeping it disabled on
    # CPU avoids GradScaler's "CUDA is not available" warning.
    scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

    for _ in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        check_accuracy(val_loader, model, device=DEVICE)

        # Optional: dump validation predictions for visual inspection.
        # save_predictions_as_images(
        #     val_loader, model, folder="saved_images/", device=DEVICE
        # )

In [None]:
if __name__ == "__main__":
    # Start training only when this file is executed directly, not when imported.
    main()