In [9]:
import os

import albumentations as A
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from albumentations.pytorch import ToTensorV2

%run model.ipynb
%run utils.ipynb


UNET(
  (ups): ModuleList(
    (0): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
    (1): DoubleConv(
      (relusig): ReluSIG(
        (gelu): GELU(approximate='none')
        (sigmoid): Sigmoid()
      )
      (conv): Sequential(
        (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReluSIG(
          (gelu): GELU(approximate='none')
          (sigmoid): Sigmoid()
        )
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReluSIG(
          (gelu): GELU(approximate='none')
          (sigmoid): Sigmoid()
        )
      )
    )
    (2): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
    (3): DoubleConv(
      (relusig): ReluSIG(
        (gelu): GELU(approximate=

In [6]:
# --- Training configuration -------------------------------------------------
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
NUM_EPOCHS = 100
NUM_WORKERS = 0
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 160
# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine
# DataLoader warns and ignores it, so tie it to the device actually in use.
PIN_MEMORY = DEVICE == "cuda"
# Selects the pretrained checkpoint main() loads before training:
# 1 -> hyp1, 2 -> hyp2, any other value -> no checkpoint is loaded.
LOAD_MODEL = 1
TRAIN_IMG_DIR = "images"
TRAIN_MASK_DIR = "masks"
VAL_IMG_DIR = "images_val"
VAL_MASK_DIR = "masks_val"

# Directory holding the pretrained checkpoints. Set the SEG_WEIGHTS_DIR
# environment variable on other machines instead of editing this
# machine-specific default.
WEIGHTS_DIR = os.environ.get(
    "SEG_WEIGHTS_DIR",
    "C:/Users/haZAR/Desktop/Subahnshu_Sethi_ResearchTeam/Segmentation/Model_weights",
)
hyp1 = f"{WEIGHTS_DIR}/Hypothises1CONCATskipAND ATTENTION.pth.tar"
hyp2 = f"{WEIGHTS_DIR}/Hypothisesconcatskip_att_gelusig.pth.tar"




In [7]:

def train_fn(loader, model, optimizer, loss_fn):
    """Run one training epoch over ``loader`` and return the mean batch loss.

    Args:
        loader: iterable of ``(data, targets)`` batches. ``targets`` is cast to
            float and given a new axis 1 via ``unsqueeze(1)`` so it lines up
            with the model output — presumably (N, H, W) masks vs.
            (N, 1, H, W) logits; TODO confirm against the dataset.
        model: module mapping ``data`` to predictions. This function does not
            switch the model into train mode itself; the caller must.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar
            tensor.

    Returns:
        The arithmetic mean of the per-batch losses as a float, or ``nan`` if
        ``loader`` yielded no batches (instead of raising ZeroDivisionError).
    """
    loop = tqdm.tqdm(loader)
    losses = []

    for data, targets in loop:
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        predictions = model(data)
        loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() forces a device sync; do it once per batch and reuse it.
        batch_loss = loss.item()
        losses.append(batch_loss)
        loop.set_postfix(loss=batch_loss)

    if not losses:
        return float("nan")
    return sum(losses) / len(losses)
        



In [8]:
def _build_transforms():
    """Return the ``(train_transform, val_transform)`` albumentations pipelines.

    Both resize to IMAGE_HEIGHT x IMAGE_WIDTH and convert to tensors; only the
    training pipeline adds a random horizontal flip.
    """
    # NOTE(review): p=0.03 flips only ~3% of training samples; 0.3-0.5 is more
    # typical for augmentation — confirm this value is intentional.
    # NOTE(review): there is no A.Normalize and ToTensorV2 does not rescale
    # pixel values; presumably the dataset behind get_loaders already yields
    # float images — verify.
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.HorizontalFlip(p=0.03),
            ToTensorV2(),
        ],
    )
    val_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            ToTensorV2(),
        ],
    )
    return train_transform, val_transform


def _maybe_load_pretrained(model):
    """Load the checkpoint selected by LOAD_MODEL (1 -> hyp1, 2 -> hyp2) into ``model``.

    Any other LOAD_MODEL value leaves ``model`` untouched (train from scratch).
    """
    checkpoint_path = {1: hyp1, 2: hyp2}.get(LOAD_MODEL)
    if checkpoint_path is None:
        return
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # machine (and vice versa) instead of failing inside torch.load.
    load_checkpoint(torch.load(checkpoint_path, map_location=DEVICE), model)


def main():
    """Train the UNet, checkpointing and validating after every epoch.

    Optionally starts from a pretrained checkpoint (see LOAD_MODEL), reports
    validation accuracy once before training, and plots the per-epoch mean
    training loss after all NUM_EPOCHS epochs have finished.
    """
    train_transform, val_transform = _build_transforms()

    model = load_model('UNet').to(device=DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transform,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    _maybe_load_pretrained(model)

    # Baseline validation score before any training in this run.
    check_accuracy(val_loader, model, loss_fn, device=DEVICE)

    epoch_losses = []
    for _epoch in range(NUM_EPOCHS):
        model.train()
        epoch_losses.append(train_fn(train_loader, model, optimizer, loss_fn))

        save_checkpoint({
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        })

        # Same argument list as the pre-training call above (previously
        # loss_fn was omitted here).
        check_accuracy(val_loader, model, loss_fn, device=DEVICE)

        save_predictions_as_imgs(
            val_loader, model, folder="saved_images/", device=DEVICE
        )

    fig, ax = plt.subplots()
    ax.plot(epoch_losses)
    ax.set(title="Mean training loss per epoch", xlabel="Epoch", ylabel="Loss")
    plt.show()


if __name__ == "__main__":
    main()



    

loading checkpoint
Got 1828048/1868800 with acc 97.82


  8%|▊         | 11/138 [00:03<00:35,  3.56it/s, loss=0.032] 


KeyboardInterrupt: 