In [2]:
import numpy as np
import torch
from tqdm import tqdm

from torch import nn
from torch.optim import Adam
import matplotlib.pyplot as plt
from albumentations.pytorch import ToTensorV2
import albumentations as A

from model.FoodDataset import FoodDataset
from model.architectures.unet import Unet_model

In [3]:
# Root folder of the generated dataset. Raw strings keep every backslash of the
# Windows path literal, instead of relying on Python leaving unknown escape
# sequences such as "\L" or "\d" untouched (and on "\a" being escaped by hand).
DATA_ROOT = r"E:\Licenta_DOC\API_Segmentation\data\generated"

# Each entry is a directory handed to FoodDataset: images and their annotations.
image_paths = [rf"{DATA_ROOT}\img_dir"]
seg_paths = [rf"{DATA_ROOT}\ann_dir"]

# Flag kept from the original cell; nothing visible in this notebook reads it.
LOAD_MODEL = True
def get_images(batch_size=32, shuffle=True, pin_memory=True, seed=None):
    """Build the train and test DataLoaders for the food segmentation dataset.

    The dataset is created from the module-level ``image_paths`` / ``seg_paths``
    with the module-level transform ``t1`` and split 80% / 20% at random.

    Args:
        batch_size: Number of samples per batch for both loaders.
        shuffle: Whether the *training* loader reshuffles every epoch.
        pin_memory: Forwarded to both DataLoaders.
        seed: Optional integer seed for the train/test split. ``None`` keeps
            the previous behaviour (split drawn from the global RNG).

    Returns:
        Tuple ``(train_batch, test_batch)`` of DataLoaders.
    """
    data = FoodDataset(image_paths, seg_paths, transform=t1)
    train_size = int(0.8 * len(data))
    test_size = len(data) - train_size

    # A dedicated generator makes the split repeatable across kernel restarts,
    # so a reloaded checkpoint is not evaluated on images it was trained on.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_dataset, test_dataset = torch.utils.data.random_split(
        data, [train_size, test_size], **split_kwargs
    )

    train_batch = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory
    )
    # Evaluation does not benefit from shuffling; a fixed order keeps the
    # "first test batch" used for visual inspection stable between runs.
    test_batch = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory
    )
    return train_batch, test_batch


In [4]:

# Device used for every model and tensor in this notebook.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Preprocessing applied to every sample: resize to 256x256, normalise each
# channel with mean 0.5 / std 0.5, then convert to a torch tensor.
_resize = A.Resize(height=256, width=256)
_normalize = A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
t1 = A.Compose([_resize, _normalize, ToTensorV2()])

# get_images reads the global `t1`, so it must be defined before this call.
train_batch, test_batch = get_images(batch_size=8)



In [5]:
from model.checkpoints.checkpoints import load_checkpoint
import timm as timm
from torchsummary import summary

from model.architectures.VisionTransformerSegmentation import ViTSegmentationModel

# Earlier experiments, kept for reference:
#   vit_model = timm.create_model('vit_base_patch16_224', pretrained=True)
#   summary(vit_segmentation_model, (3, 224, 224))
#   load_checkpoint(torch.load('checkpoints/checkpoint-vit-ce_dice_loss.pth.tar'), vit_segmentation_model)

num_classes = 104  # Number of classes for the segmentation task
vit_model_name = 'vit_base_patch16_224'

# Build the ViT-based segmentation network, then move it to the active device.
vit_segmentation_model = ViTSegmentationModel(vit_model_name, num_classes)
vit_segmentation_model = vit_segmentation_model.to(DEVICE)


In [6]:
# Layer-by-layer overview of the ViT segmentation model for a 3x224x224 input.
summary(vit_segmentation_model, (3, 224, 224))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
          Identity-2             [-1, 196, 768]               0
        PatchEmbed-3             [-1, 196, 768]               0
           Dropout-4             [-1, 197, 768]               0
          Identity-5             [-1, 197, 768]               0
         LayerNorm-6             [-1, 197, 768]           1,536
            Linear-7            [-1, 197, 2304]       1,771,776
           Dropout-8         [-1, 12, 197, 197]               0
            Linear-9             [-1, 197, 768]         590,592
          Dropout-10             [-1, 197, 768]               0
        Attention-11             [-1, 197, 768]               0
         Identity-12             [-1, 197, 768]               0
         Identity-13             [-1, 197, 768]               0
        LayerNorm-14             [-1, 1

In [20]:
from model.architectures.SAM_Architecture import SAM_Architecture

# SAM-based segmentation network (this configuration works).
# NOTE(review): 104 presumably mirrors `num_classes` defined for the ViT model - confirm.
model = SAM_Architecture(104)
model = model.to(DEVICE)

# Layer-by-layer overview for a 3x256x256 input.
summary(model, input_size=(3, 256, 256))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 16, 16]         590,592
        PatchEmbed-2          [-1, 16, 16, 768]               0
         LayerNorm-3          [-1, 16, 16, 768]           1,536
            Linear-4         [-1, 16, 16, 2304]       1,771,776
            Linear-5          [-1, 16, 16, 768]         590,592
         Attention-6          [-1, 16, 16, 768]               0
         LayerNorm-7          [-1, 16, 16, 768]           1,536
            Linear-8         [-1, 16, 16, 3072]       2,362,368
              GELU-9         [-1, 16, 16, 3072]               0
           Linear-10          [-1, 16, 16, 768]       2,360,064
         MLPBlock-11          [-1, 16, 16, 768]               0
            Block-12          [-1, 16, 16, 768]               0
        LayerNorm-13          [-1, 16, 16, 768]           1,536
           Linear-14         [-1, 16, 1

In [None]:
from model.losses.ce_dice_loss import CE_DICE_Loss
from model.checkpoints.checkpoints import save_checkpoint

LEARNING_RATE = 1e-4
num_epochs = 25
# Destination of the checkpoint; unchanged from the original cell.
CHECKPOINT_PATH = 'checkpoints/checkpoint-vit-ce_dice_loss.pth.tar'

loss_fn = CE_DICE_Loss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

# torch.cuda.amp is a CUDA feature: switch mixed precision on only when the
# notebook actually runs on a GPU, otherwise train in plain full precision.
use_amp = DEVICE == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for epoch in range(num_epochs):
    # check_accuracy() toggles eval mode, so make the training mode explicit.
    model.train()

    loop = tqdm(enumerate(train_batch), total=len(train_batch))
    loop.set_description(f"Epoch {epoch + 1}/{num_epochs}")
    for batch_idx, (data, targets) in loop:
        data = data.to(DEVICE)
        # The loss receives the targets as integer (long) class indices.
        targets = targets.to(DEVICE).type(torch.long)

        # forward
        with torch.cuda.amp.autocast(enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())

    # Save after every epoch instead of once after the whole run: an epoch
    # takes close to an hour, and interrupting the run used to discard all
    # progress because the checkpoint was only written after the last epoch.
    # Presumably save_checkpoint overwrites the file at CHECKPOINT_PATH - verify.
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(checkpoint, CHECKPOINT_PATH)



100%|██████████| 2492/2492 [56:04<00:00,  1.35s/it, loss=2.07] 
 12%|█▏        | 301/2492 [06:23<46:31,  1.27s/it, loss=2.35]


KeyboardInterrupt: 

In [None]:

def check_accuracy(loader, model):
    """Evaluate a multi-class segmentation model and report its metrics.

    For every ``(x, y)`` batch the model output is treated as per-class logits
    of shape ``(N, C, ...)``; the prediction is the argmax over dimension 1.
    Targets are assumed to be integer class maps with the same shape as the
    predictions - TODO confirm against FoodDataset.

    Dice and IoU are computed per class from a confusion matrix accumulated
    over the whole loader and averaged over the classes that occur in the
    targets or the predictions. (Multiplying class indices or applying
    ``logical_and`` to label maps, as done before, is only meaningful for
    binary 0/1 masks and can produce Dice values above 1.)

    Args:
        loader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        model: ``torch.nn.Module`` producing the logits. Its train/eval mode
            is restored to what it was before the call.

    Returns:
        dict with ``num_correct``, ``num_pixels``, ``accuracy`` (percent),
        ``dice`` and ``iou``. The three ratios are ``nan`` when the loader
        yields no pixels. The same values are printed.
    """
    # Evaluate on whatever device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    num_correct = 0
    num_pixels = 0
    confusion = None  # (C, C) pixel counts: rows = target class, cols = predicted class

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).long()
                logits = model(x)
                num_classes = logits.shape[1]
                # Softmax is monotonic, so argmax over the raw logits is equivalent.
                preds = torch.argmax(logits, dim=1)

                num_correct += int((preds == y).sum())
                num_pixels += preds.numel()

                if confusion is None:
                    confusion = torch.zeros(
                        num_classes, num_classes, dtype=torch.long, device=device
                    )
                # Labels outside [0, C) cannot be placed in the matrix; they still
                # count as wrong pixels in the accuracy above.
                in_range = (y >= 0) & (y < num_classes)
                flat_index = y[in_range] * num_classes + preds[in_range]
                confusion += torch.bincount(
                    flat_index, minlength=num_classes ** 2
                ).reshape(num_classes, num_classes)
    finally:
        # Put the model back in the mode the caller left it in.
        model.train(was_training)

    accuracy = dice = iou = float("nan")
    if num_pixels > 0:
        accuracy = 100.0 * num_correct / num_pixels
    if confusion is not None:
        counts = confusion.cpu().double()
        intersection = torch.diagonal(counts)
        pred_plus_target = counts.sum(dim=0) + counts.sum(dim=1)
        union = pred_plus_target - intersection
        present = union > 0  # classes seen in the targets or the predictions
        if bool(present.any()):
            iou = float((intersection[present] / union[present]).mean())
            dice = float((2 * intersection[present] / pred_plus_target[present]).mean())

    print(f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}")
    print(f"Dice score: {dice}")
    print(f"IoU: {iou}")

    return {
        "num_correct": num_correct,
        "num_pixels": num_pixels,
        "accuracy": accuracy,
        "dice": dice,
        "iou": iou,
    }



In [None]:
# Report the metrics on the training split first, then on the held-out split.
for split_name, split_loader in (("Train", train_batch), ("Test", test_batch)):
    print(split_name)
    check_accuracy(split_loader, model)

    # Keep a blank line between the two reports, as in the original layout.
    if split_name == "Train":
        pass

In [None]:

from skimage import color

# Show image / prediction / ground-truth mask for the first samples of the
# first test batch only (the loop is left after one iteration).
for x, y in test_batch:
    x = x.to(DEVICE)

    # Inference only: use eval mode and skip autograd, then restore the mode.
    was_training = model.training
    model.eval()
    with torch.no_grad():
        logits = model(x)
    model.train(was_training)

    # Softmax is monotonic, so argmax over the raw logits gives the same classes.
    preds = torch.argmax(logits, dim=1).to('cpu')
    max_label = logits.shape[1] - 1

    n_rows = 3
    # Guard against a batch holding fewer samples than rows in the figure.
    n_shown = min(n_rows, x.shape[0])

    fig, ax = plt.subplots(n_rows, 4, figsize=(72, 72))
    # The fourth column is reserved for the (disabled) overlay; hide every
    # axis frame so unused panels stay blank.
    for panel in ax.ravel():
        panel.axis("off")

    for row in range(n_shown):
        image = np.transpose(np.array(x[row].to('cpu')), (1, 2, 0))
        # Undo Normalize(mean=0.5, std=0.5) from the transform pipeline so the
        # image is back in [0, 1] instead of being clipped by imshow.
        image = np.clip(image * 0.5 + 0.5, 0.0, 1.0)
        pred_map = np.array(preds[row])
        mask_map = np.array(y[row])

        print(f"Pred{row + 1}", np.unique(pred_map, return_counts=True))
        print(f"Mask{row + 1}", np.unique(mask_map, return_counts=True))

        ax[row, 0].set_title('Image')
        ax[row, 1].set_title('Prediction')
        ax[row, 2].set_title('Mask')
        # ax[row, 3].set_title('Overlay')

        ax[row, 0].imshow(image)
        # A shared colour range makes a class look the same in both panels.
        ax[row, 1].imshow(pred_map, vmin=0, vmax=max_label)
        ax[row, 2].imshow(mask_map, vmin=0, vmax=max_label)
        # ax[row, 3].imshow(color.label2rgb(mask_map, image, saturation=1, alpha=0.5, bg_color=None))
    break