In [2]:
import os

import torch
import torchvision
from torch.utils.data import DataLoader

%run dataset.ipynb



In [3]:
def save_checkpoint(state, filename="my_checkpointm2.pth.tar"):
    """Serialize a training checkpoint to disk with ``torch.save``.

    Args:
        state: Object to serialize. ``load_checkpoint`` expects it to carry a
            ``"state_dict"`` entry once read back.
        filename: Destination path of the checkpoint file.
    """
    banner = "Saving checkpoint"
    print(banner)
    torch.save(state, filename)


In [4]:
def load_checkpoint(checkpoint, model):
    """Copy the weights stored in ``checkpoint`` into ``model``.

    Args:
        checkpoint: Mapping holding the model weights under ``"state_dict"``
            (presumably what ``save_checkpoint`` wrote, read back via ``torch.load``).
        model: Module whose parameters are overwritten in place.
    """
    print("loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

    

In [5]:
def get_loaders(
        train_dir,
        csv_file_path,
        val_dir,
        val_csv_file_path,
        batch_size,
        train_transform,
        val_transform,
        num_workers = 0,
        pin_memory = True,
):
    """Build the training and validation ``DataLoader`` pair.

    Both splits are ``CephXrayDataset`` instances (defined in ``dataset.ipynb``)
    wrapped in a ``DataLoader`` sharing ``batch_size``, ``num_workers`` and
    ``pin_memory``. Only the training loader shuffles.

    Args:
        train_dir: Image directory of the training split.
        csv_file_path: Annotation CSV of the training split.
        val_dir: Image directory of the validation split.
        val_csv_file_path: Annotation CSV of the validation split.
        batch_size: Batch size used by both loaders.
        train_transform: Transform applied to training samples.
        val_transform: Transform applied to validation samples.
        num_workers: Worker processes per loader.
        pin_memory: Forwarded to ``DataLoader``.

    Returns:
        ``(train_loader, val_loader)``.
    """
    def build(csv_path, image_dir, transform, shuffle):
        # The two splits differ only in their inputs and in shuffling.
        dataset = CephXrayDataset(
            csv_file_path=csv_path,
            image_path=image_dir,
            transform=transform,
        )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle,
        )

    train_loader = build(csv_file_path, train_dir, train_transform, shuffle=True)
    val_loader = build(val_csv_file_path, val_dir, val_transform, shuffle=False)
    return train_loader, val_loader

    

In [None]:
import numpy as np

In [None]:
def check_accuracy(loader , model ,loss_fn ,device = "cuda"):
    """Print and return the mean validation loss of ``model`` over ``loader``.

    Despite the name, no accuracy is computed. The function averages the
    per-batch values of ``loss_fn(model(image), heatmap)``.

    Args:
        loader: Iterable yielding ``(image, heatmap)`` batches.
        model: Module to evaluate. It is put in eval mode for the pass and
            restored to the mode it was in before the call.
        loss_fn: Callable ``(output, target)`` returning a scalar tensor.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The mean per-batch loss as a float, or ``nan`` if ``loader`` is empty.
    """
    was_training = model.training
    model.eval()
    valid_losses = []
    try:
        with torch.no_grad():
            for image, heatmap in loader:
                image = image.float().to(device=device)
                heatmap = heatmap.float().to(device=device)
                output = model(image)
                valid_losses.append(loss_fn(output, heatmap).item())
    finally:
        # Restore the caller's mode: a training loop that validates between
        # epochs must not keep training with dropout/batch-norm in eval mode.
        model.train(was_training)

    # np.mean([]) warns and returns nan; make the empty case explicit.
    valid_loss = float(np.mean(valid_losses)) if valid_losses else float("nan")
    print('Validation loss: {:.6f}'.format(valid_loss))
    return valid_loss


<!-- def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    model.eval()
    for idx, (image, heatmap) in enumerate(loader):
        image, heatmap = image.float().to(device = device), heatmap.float().to(device = device)
        with torch.no_grad():
            preds = model(image)

# Save the predictions
            heatmap_sum = heatmap.sum(axis=1, keepdims=True)

# Now the heatmap has the same number of channels as the image and they can be combined
            imposed_image = image + heatmap_sum

            preds_sum = preds.sum(axis=1, keepdims=True)
            imposed_image_preds = image + preds_sum
            torchvision.utils.save_image(imposed_image_preds, f"{folder}/imposed_preds_{idx}.png")
            
                
            torchvision.utils.save_image(imposed_image, f"{folder}{idx}.png")
         -->

In [None]:
def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda", threshold=0.5
):
    """Save predicted and ground-truth heatmap overlays for every batch.

    For batch ``idx`` yielded by ``loader``, two images are written into ``folder``:

    * ``imposed_preds_{idx}.png``: the input image plus the channel-sum of the
      model output;
    * ``imposed_actual_{idx}.png``: the input image plus the channel-sum of the
      target heatmap.

    The indices of output elements greater than ``threshold`` are printed for
    each batch.

    Args:
        loader: Iterable yielding ``(image, heatmap)`` batches.
        model: Module producing heatmap predictions. It is put in eval mode for
            the pass and restored to its previous mode afterwards.
        folder: Output directory. It is created if missing.
        device: Device the batches are moved to.
        threshold: Cut-off for the printed coordinates (previously hard-coded
            to 0.5).
    """
    os.makedirs(folder, exist_ok=True)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for idx, (image, heatmap) in enumerate(loader):
                image = image.float().to(device=device)
                heatmap = heatmap.float().to(device=device)
                preds = model(image)

                coords = torch.nonzero(preds > threshold)
                print(f"Coordinates of predicted pixels: {coords}")

                # Summing over dim 1 gives a single-channel map that broadcasts
                # onto the image (assumes NCHW layout; TODO confirm).
                imposed_image_preds = image + preds.sum(dim=1, keepdim=True)
                imposed_image = image + heatmap.sum(dim=1, keepdim=True)

                torchvision.utils.save_image(
                    imposed_image_preds,
                    os.path.join(folder, f"imposed_preds_{idx}.png"),
                )
                torchvision.utils.save_image(
                    imposed_image,
                    os.path.join(folder, f"imposed_actual_{idx}.png"),
                )
    finally:
        # Do not leave a model that is still being trained stuck in eval mode.
        model.train(was_training)

        