In [4]:
%tb
import time
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import logging
from torch.nn import functional as F
import torch.optim as optim
from tqdm import tqdm
from DataLoader import RoadDataset
import argparse
from cfg import parse_args
from segment_anything import SamPredictor, sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from utils import FocalDiceloss_IoULoss, SegMetrics
import einops
 

ckeckpoint_dir = '../workdir/ck/SAM_checkpoint/sam_vit_b_01ec64.pth'
device = torch.device("mps")
train_root = "../../graduating_project/Dataset/DeepGlobeRoadExtraction/road/train/"

# model_save_path = os.path.join(args.work_dir, args.run_name)
# os.makedirs(model_save_path, exist_ok=True)
dataset = RoadDataset(data_root=train_root, image_size=512, train=True, box = True, points=True, transform=True)
train_dataloader = DataLoader(dataset, 2, True)
model = sam_model_registry['vit_b'](checkpoint=ckeckpoint_dir).to(device)

for i, batch in enumerate(train_dataloader):
    image = batch['image'].to(device)
    mask = batch['mask'].to(device)

    box = batch['box']
    point_coords = batch["point_coords"]
    point_labels = batch["point_labels"]
    break
# sam_trans = ResizeLongestSide(model.image_encoder.img_size)
# box = np.asarray(box)
# box = sam_trans.apply_boxes(box, (mask.shape[-2], mask.shape[-1]))
# box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
# if len(box_torch.shape) == 2:
#                 box_torch = box_torch[:, None, :] # (B, 1, 4)

for n, value in model.image_encoder.named_parameters():
    if "Adapter" in n:
           value.requires_grad = True
    else:
            value.requires_grad = False

image_embedding = model.image_encoder(image)

coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)

points = (coords_torch, labels_torch)
with torch.no_grad():
                sparse_embeddings, dense_embeddings = model.prompt_encoder(
                points = points,
                boxes = box.to(device),
                masks = None)

                pred_masks, iou_pred = model.mask_decoder(
                    image_embeddings = image_embedding.to(device),
                    image_pe = model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings = sparse_embeddings,
                    dense_prompt_embeddings = dense_embeddings,
                    multimask_output = False
                )
pred = F.interpolate(pred_masks,size=(512,512))

seg_loss = FocalDiceloss_IoULoss()
mask = einops.repeat(mask, "b h w -> b 1 h w")
loss = seg_loss(pred, mask, iou_pred)

metric = SegMetrics(pred, mask, ['iou', 'dice'])
print(f"IOU: {metric[0]}, DICE: {metric[1]}")
print(f"loss: {loss}")
print(f"shape of sparse_embeddings {sparse_embeddings.shape}")
print(f"shape of dense_embeddings {dense_embeddings.shape}")
print(f"shape of image embedding {image_embedding.shape}")
print(f"shape of pred_masks {pred.shape}")
print(f"shape of masks {mask.shape}")
print(f"shape of image {image.shape}")
print(f"shape of box {box.shape}")

AttributeError: 'tuple' object has no attribute 'to'

*******interpolate
*******load ../workdir/ck/SAM_checkpoint/sam_vit_b_01ec64.pth


  'image': torch.tensor(image).float(),
  'mask': torch.tensor(mask).float(),


IOU: 0.05095195397734642, DICE: 0.09674651175737381
loss: -1.120917797088623
shape of sparse_embeddings torch.Size([2, 7, 256])
shape of dense_embeddings torch.Size([2, 256, 32, 32])
shape of image embedding torch.Size([2, 256, 32, 32])
shape of pred_masks torch.Size([2, 1, 512, 512])
shape of masks torch.Size([2, 1, 512, 512])
shape of image torch.Size([2, 3, 512, 512])
shape of box torch.Size([2, 4])


In [1]:
from utils import generate_point

for i, batch in enumerate(train_dataloader):
    image = batch['image'].to(device)
    mask = batch['mask'].to(device)


points = generate_point(pred, mask, pred_masks, batch, 5)

with torch.no_grad():
                sparse_embeddings, dense_embeddings = model.prompt_encoder(
                points = points,
                boxes = box.to(device),
                masks = None)

                pred_masks, iou_pred = model.mask_decoder(
                    image_embeddings = image_embedding.to(device),
                    image_pe = model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings = sparse_embeddings,
                    dense_prompt_embeddings = dense_embeddings,
                    multimask_output = False
                )

pred = F.interpolate(pred_masks,size=(1024,1024))
seg_loss = FocalDiceloss_IoULoss()
mask = einops.repeat(mask, "b h w -> b 1 h w")
loss = seg_loss(pred, mask, iou_pred)

metric = SegMetrics(pred, mask, ['iou', 'dice'])
print(metric)
print(f"loss: {loss}")
print(f"shape of sparse_embeddings {sparse_embeddings.shape}")
print(f"shape of dense_embeddings {dense_embeddings.shape}")
print(f"shape of image embedding {image_embedding.shape}")
print(f"shape of pred_masks {pred.shape}")
print(f"shape of masks {mask.shape}")
print(f"shape of image {image.shape}")
print(f"shape of box {box.shape}")
print(f"shape of points {points.shape}")

NameError: name 'train_dataloader' is not defined

In [2]:
optimizer = torch.optim.Adam(model.mask_decoder.parameters(), lr=0.001, weight_decay=0)
loss = seg_loss(pred, mask, iou_pred)
model_save_path = os.path.join('../workdir/', 'sam-satellite')


In [7]:
losses = []
model.train()
for epoch in range(2):
    epoch_loss = 0
    # Just train on the first 20 examples
    for step, batch in enumerate(tqdm(train_dataloader)):

        image = batch['image'].to(device)
        mask = batch['mask'].to(device)

        if 'box' in batch:
            box = batch['box']

        for name, param in model.named_parameters():
                param.requires_grad = "image_encoder" not in name
        
        image_embedding = model.image_encoder(image)

        # do not compute gradients for image encoder and prompt encoder
        with torch.no_grad():
            # convert box to 1024x1024 grid
            sparse_embeddings, dense_embeddings = model.prompt_encoder(
                    points = None,
                    boxes = box.to(device),
                    masks = None)
            
        pred_masks, iou_pred = model.mask_decoder(
                    image_embeddings = image_embedding.to(device),
                    image_pe = model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings = sparse_embeddings,
                    dense_prompt_embeddings = dense_embeddings,
                    multimask_output = False
            )       
        
        pred = F.interpolate(pred_masks,size=(1024, 1024))

            # to add the mask to the loss function it need to be shape of [B, 1, H, W] (1)
            # but it in shape of [B, H, W] (2), so the blow line convert from shape 2 --> 1
        mask = einops.repeat(mask, "b h w -> b 1 h w")
        loss = seg_loss(pred, mask, iou_pred)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    epoch_loss /= step
    losses.append(epoch_loss)
    print(f'EPOCH: {epoch}, Loss: {epoch_loss}')
    # save the model checkpoint
    torch.save(model.state_dict(), os.path.join(model_save_path, 'sam_model_latest.pth'))
    # save the best model
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), os.path.join(model_save_path, 'sam_model_best.pth'))

  2%|▏         | 3/127 [00:14<10:19,  5.00s/it]


KeyboardInterrupt: 

In [8]:
def train_sam(propmt_grad, model, optimizer, criterion, train_loader, device, epoch, no_epochs=2):

    epoch_loss = 0

    model.train()
    for epoch in range(no_epochs):
        with tqdm(total=len(train_loader), desc=f'Epoch {epoch}') as pbar:
                for batch_idx, batch in enumerate(train_loader):

                    image = batch['image'].to(device)
                    mask = batch['mask'].to(device)

                    if 'box' in batch:
                        box = batch['box']

                    optimizer.zero_grad()

                    for name, param in model.named_parameters():
                        param.requires_grad = "image_encoder" not in name

                    image_embedding = model.image_encoder(image)

                    if propmt_grad is True:
                        sparse_embeddings, dense_embeddings = model.prompt_encoder(
                        points = None,
                        boxes = box.to(device),
                        masks = None
                    )
                    else:
                        with torch.no_grad():
                            sparse_embeddings, dense_embeddings = model.prompt_encoder(
                            points = None,
                            boxes = box.to(device),
                            masks = None
                    )
                    pred_masks, iou_pred = model.mask_decoder(
                            image_embeddings = image_embedding.to(device),
                            image_pe = model.prompt_encoder.get_dense_pe(),
                            sparse_prompt_embeddings = sparse_embeddings,
                            dense_prompt_embeddings = dense_embeddings,
                            multimask_output = False
                    )       
                    
                    # The pred_masks in shape of [B, C, 256, 256] we need to upsample or downsample to the img input size 
                    # So the below line will ## Resize to the ordered output size
                    pred = F.interpolate(pred_masks,size=(1024, 1024))

                    # to add the mask to the loss function it need to be shape of [B, 1, H, W] (1)
                    # but it in shape of [B, H, W] (2), so the blow line convert from shape 2 --> 1
                    mask = einops.repeat(mask, "b h w -> b 1 h w")

                    loss = criterion(pred, mask, iou_pred)

                    epoch_loss += loss.item()
                    
                    loss.backward()
                    optimizer.step()

                    pbar.set_postfix({'loss': epoch_loss / (batch_idx + 1)})
                    pbar.update(1)
                    
        epoch_loss /= len(train_dataloader)
        print(f'Epoch {epoch}/{1}, Loss: {epoch_loss:.4f}')

dataset = RoadDataset(data_root=train_root, image_size=1024, train=True, box = True, transform=True)
train_dataloader = DataLoader(dataset, 2, True)
model = sam_model_registry['vit_b'](checkpoint=ckeckpoint_dir).to(device)
train_sam(True, model, optimizer, seg_loss, train_dataloader, device, 1)

Epoch 0: 100%|██████████| 127/127 [08:05<00:00,  3.82s/it, loss=-144]


Epoch 0/1, Loss: -144.0107


Epoch 1: 100%|██████████| 127/127 [08:03<00:00,  3.80s/it, loss=-147]

Epoch 1/1, Loss: -147.1673





In [6]:
print(len(train_dataloader))

708
