In [1]:
import torch
%run dataset.ipynb
import torchvision
from torch.utils.data import DataLoader



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def save_checkpoint(state, filename="my_checkpoint_scheduler.pth.tar"):
    """Announce the save on stdout, then serialize ``state`` with ``torch.save``.

    Args:
        state: Object handed to ``torch.save`` unchanged. Presumably a dict
            holding at least a ``"state_dict"`` entry, since that is the key
            ``load_checkpoint`` reads -- confirm against the training loop.
        filename: Destination passed to ``torch.save``.

    Returns:
        None.
    """
    print("Saving checkpoint")
    torch.save(obj=state, f=filename)


In [3]:
def load_checkpoint(checkpoint, model):
    """Copy the weights stored under ``checkpoint["state_dict"]`` into ``model``.

    Args:
        checkpoint: Mapping that must contain a ``"state_dict"`` key; a
            missing key raises ``KeyError`` (after the progress message has
            already been printed).
        model: Object exposing ``load_state_dict``; it is updated in place.

    Returns:
        None. The result of ``load_state_dict`` is discarded.
    """
    print("loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

    

In [5]:
def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=0,
    pin_memory=True,
):
    """Build the training and validation ``DataLoader`` pair.

    Each split wraps a ``Medical`` dataset constructed from an image
    directory, a mask directory and a transform. ``Medical`` is not defined
    in this cell; presumably it comes from ``%run dataset.ipynb`` at the top
    of the notebook -- confirm.

    Args:
        train_dir: Passed to ``Medical`` as ``image_dir`` for the train split.
        train_maskdir: Passed to ``Medical`` as ``mask_dir`` for the train split.
        val_dir: Passed to ``Medical`` as ``image_dir`` for the validation split.
        val_maskdir: Passed to ``Medical`` as ``mask_dir`` for the validation split.
        batch_size: Batch size used by both loaders.
        train_transform: Passed to ``Medical`` as ``transform`` for the train split.
        val_transform: Passed to ``Medical`` as ``transform`` for the validation split.
        num_workers: Forwarded to both ``DataLoader`` instances.
        pin_memory: Forwarded to both ``DataLoader`` instances.

    Returns:
        ``(train_loader, val_loader)``. Only the training loader shuffles.
    """
    # Settings that are identical for both splits.
    shared_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }

    training_set = Medical(
        image_dir=train_dir, mask_dir=train_maskdir, transform=train_transform
    )
    train_loader = DataLoader(training_set, shuffle=True, **shared_kwargs)

    validation_set = Medical(
        image_dir=val_dir, mask_dir=val_maskdir, transform=val_transform
    )
    val_loader = DataLoader(validation_set, shuffle=False, **shared_kwargs)

    return train_loader, val_loader

    

In [6]:
def check_accuracy(loader , model , loss_fn ,device = "cuda"):
    """Print the element-wise accuracy of thresholded predictions over ``loader``.

    For every ``(x, y)`` batch the model output is passed through a sigmoid,
    thresholded at 0.5 and compared element-wise with ``y`` after a channel
    axis is inserted at dim 1.

    Args:
        loader: Iterable yielding ``(x, y)`` tensor pairs.
            Assumes ``y`` has no channel axis, e.g. ``(batch, H, W)``, so that
            ``unsqueeze(1)`` lines it up with the model output -- TODO confirm
            against the dataset.
        model: Callable module; switched to eval mode and left there.
        loss_fn: Unused. Kept only so existing call sites keep working.
        device: Device both tensors are moved to before the forward pass.

    Returns:
        None. The summary is printed to stdout. An empty loader reports
        ``0/0`` with accuracy ``nan`` instead of raising ``ZeroDivisionError``.

    Note:
        The model is NOT switched back to training mode. A caller that
        resumes training afterwards must call ``model.train()`` itself.
    """
    num_correct = 0
    num_pixels = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = (torch.sigmoid(model(x)) > 0.5).float()
            # Accumulate as a tensor so no host sync is forced per batch.
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)

    # int() accepts both the initial Python 0 and a 0-dim tensor; doing the
    # division on Python ints keeps the percentage exact for large counts.
    num_correct = int(num_correct)
    accuracy = num_correct / num_pixels * 100 if num_pixels else float("nan")
    print(
        f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}"
    )


In [None]:
def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    """Write each batch's thresholded prediction and its target mask as PNGs.

    For batch number ``idx`` two files are written inside ``folder``:
    ``pred_{idx}.png`` (sigmoid output thresholded at 0.5) and ``{idx}.png``
    (the target ``y`` with a channel axis inserted at dim 1).

    Args:
        loader: Iterable yielding ``(x, y)`` tensor pairs.
        model: Callable module; switched to eval mode and left there.
        folder: Output directory, with or without a trailing separator.
            It is created if it does not exist.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        None.

    Note:
        The model is NOT switched back to training mode. A caller that
        resumes training afterwards must call ``model.train()`` itself.
    """
    # Function-scope import keeps this cell self-contained.
    import os

    # An empty string means "current directory"; makedirs("") would raise.
    if folder:
        os.makedirs(folder, exist_ok=True)

    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        # Both paths go through os.path.join so a folder without a trailing
        # slash no longer writes the mask outside the folder.
        torchvision.utils.save_image(
            preds, os.path.join(folder, f"pred_{idx}.png")
        )
        torchvision.utils.save_image(
            y.unsqueeze(1), os.path.join(folder, f"{idx}.png")
        )