In [1]:
r""" Visual Prompt Encoder training (validation) code """

import argparse
import os
import pdb

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from common import utils
from common.evaluation import Evaluator
from common.logger import AverageMeter, Logger
from data.dataset import FSSDataset
from model.VRP_encoder_SEN import VRP_encoder_SEN, build_SEN
from SAM2pred import SAM_pred

import model.DETR.util.misc as detr_utils
import sys
import math

In [2]:
import sys
import os

# Set the environment variables for distributed training
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
os.environ['WORLD_SIZE'] = '1'
os.environ['RANK'] = '0'


sys.argv = ['run',
    '--datapath', '.',
    '--logpath', 'gso_train_detr',
    '--benchmark', 'gso_detr',
    '--backbone', 'resnet50',
    '--fold', '0',
    '--condition', 'mask',
    '--num_query', '50',
    '--epochs', '1',
    '--lr', '1e-4',
    '--bsz', '4',
    '--local_rank', '0',
    # '--load_weight', "checkpoints/gso_train1/best_model_ep14.ptrom",
    '--sam_weight', "/home/icetenny/senior-1/segment-anything/model/sam_vit_h_4b8939.pth"
]


# Arguments parsing
parser = argparse.ArgumentParser(
    description="Visual Prompt Encoder Pytorch Implementation"
)
parser.add_argument("--datapath", type=str, default="ice")
parser.add_argument(
    "--benchmark",
    type=str,
    default="coco",
    choices=["pascal", "coco", "fss", "gso", "gso_detr"],
)
parser.add_argument("--logpath", type=str, default="")
parser.add_argument(
    "--bsz", type=int, default=2
)  # batch size = num_gpu * bsz default num_gpu = 4
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=1e-6)
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--nworker", type=int, default=8)
parser.add_argument("--seed", type=int, default=321)
parser.add_argument("--fold", type=int, default=0, choices=[0, 1, 2, 3])
parser.add_argument(
    "--condition",
    type=str,
    default="scribble",
    choices=["point", "scribble", "box", "mask"],
)
parser.add_argument(
    "--use_ignore",
    type=bool,
    default=True,
    help="Boundaries are not considered during pascal training",
)
parser.add_argument(
    "--local_rank",
    type=int,
    default=0,
    help="number of cpu threads to use during batch generation",
)
parser.add_argument("--num_query", type=int, default=50)
parser.add_argument(
    "--backbone",
    type=str,
    default="resnet50",
    choices=["vgg16", "resnet50", "resnet101"],
)

parser.add_argument(
    "--nshot",
    type=int,
    default=14,
)

parser.add_argument(
    "--load_weight",
    type=str,
    default="",
)

parser.add_argument(
    "--sam_weight",
    type=str,
    default="/home/icetenny/senior-1/segment-anything/model/sam_vit_h_4b8939.pth",
)


# DETR

parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')

# * Segmentation
parser.add_argument('--masks', action='store_true',
                    help="Train segmentation head if the flag is provided")

# Loss
parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                    help="Disables auxiliary decoding losses (loss at each layer)")
# * Matcher
parser.add_argument('--set_cost_class', default=1, type=float,
                    help="Class coefficient in the matching cost")
parser.add_argument('--set_cost_bbox', default=5, type=float,
                    help="L1 box coefficient in the matching cost")
parser.add_argument('--set_cost_giou', default=2, type=float,
                    help="giou box coefficient in the matching cost")


# * Loss coefficients
parser.add_argument('--mask_loss_coef', default=1, type=float)
parser.add_argument('--dice_loss_coef', default=1, type=float)
parser.add_argument('--bbox_loss_coef', default=5, type=float)
parser.add_argument('--giou_loss_coef', default=2, type=float)
parser.add_argument('--eos_coef', default=0.1, type=float,
                    help="Relative classification weight of the no-object class")


args = parser.parse_args()
print(args)


Namespace(datapath='.', benchmark='gso_detr', logpath='gso_train_detr', bsz=4, lr=0.0001, weight_decay=1e-06, epochs=1, nworker=8, seed=321, fold=0, condition='mask', use_ignore=True, local_rank=0, num_query=50, backbone='resnet50', nshot=14, load_weight='', sam_weight='/home/icetenny/senior-1/segment-anything/model/sam_vit_h_4b8939.pth', clip_max_norm=0.1, masks=False, aux_loss=True, set_cost_class=1, set_cost_bbox=5, set_cost_giou=2, mask_loss_coef=1, dice_loss_coef=1, bbox_loss_coef=5, giou_loss_coef=2, eos_coef=0.1)


In [None]:
# def train_nshot(
#     args, epoch, model, dataloader, criterion, optimizer, scheduler, nshot=5
# ):
#     r"""Train VRP_encoder model"""

#     # pdb.set_trace()
#     utils.fix_randseed(args.seed + epoch)
#     model.module.train_mode()
#     criterion.train()

#     # average_meter = AverageMeter(dataloader.dataset)

#     for idx, batch in enumerate(dataloader):

#         batch = utils.to_cuda(batch)
#         # protos = model.module.forward_nshot(
#         #     args.condition,
#         #     batch["query_img"],
#         #     batch["support_imgs"],
#         #     batch["support_masks"],
#         #     training,
#         #     nshot=nshot,
#         # )

#         outputs = model.module.forward(
#             args.condition,
#             batch["query_img"],
#             batch["support_imgs"][:, 0],
#             batch["support_masks"][:, 0],
#             training=True,
#             # nshot=nshot,
#         )

#         loss_dict = criterion(outputs, targets)
#         weight_dict = criterion.weight_dict
#         losses = sum(
#             loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict
#         )

#         # reduce losses over all GPUs for logging purposes
#         loss_dict_reduced = utils.reduce_dict(loss_dict)
#         loss_dict_reduced_unscaled = {
#             f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
#         }
#         loss_dict_reduced_scaled = {
#             k: v * weight_dict[k]
#             for k, v in loss_dict_reduced.items()
#             if k in weight_dict
#         }
#         losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

#         loss_value = losses_reduced_scaled.item()

#         if not math.isfinite(loss_value):
#             print("Loss is {}, stopping training".format(loss_value))
#             print(loss_dict_reduced)
#             sys.exit(1)

#         optimizer.zero_grad()
#         losses.backward()
#         if max_norm > 0:
#             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
#         optimizer.step()

#         metric_logger.update(
#             loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled
#         )
#         metric_logger.update(class_error=loss_dict_reduced["class_error"])
#         metric_logger.update(lr=optimizer.param_groups[0]["lr"])

#         loss = model.module.compute_objective(logit_mask, batch["query_mask"])

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         scheduler.step()
#         # print(loss)

#         area_inter, area_union = Evaluator.classify_prediction(
#             pred_mask.squeeze(1), batch
#         )
#         average_meter.update(
#             area_inter, area_union, batch["class_id"], loss.detach().clone()
#         )
#         average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=200)

#     average_meter.write_result("Training" if training else "Validation", epoch)
#     avg_loss = utils.mean(average_meter.loss_buf)
#     miou, fb_iou = average_meter.compute_iou()

#     return avg_loss, miou, fb_iou

In [None]:
def train_nshot(
    args, epoch, model, dataloader, criterion, optimizer, scheduler, nshot=5, max_norm=0
):
    r"""Train VRP_encoder model"""

    # pdb.set_trace()
    utils.fix_randseed(args.seed + epoch)
    model.module.train_mode()
    criterion.train()

    metric_logger = detr_utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', detr_utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('class_error', detr_utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    # for idx, batch in enumerate(dataloader):

    for batch in metric_logger.log_every(dataloader, print_freq, header):
        print("Batch----------------------")

        batch = utils.to_cuda(batch)

        outputs = model.module.forward(
            args.condition,
            batch["query_img"],
            batch["support_imgs"][:, 0],
            batch["support_masks"][:, 0],
            training=True,
            # nshot=nshot,
        )

        print(batch['query_img'].shape)

        print(outputs["pred_logits"].shape, outputs["pred_boxes"].shape)

        targets = []

        for bbox, label in zip(batch['bboxes'], batch['unique_obj_id']):
            target = {}
            target['boxes'] = torch.Tensor(bbox)
            target["labels"] = torch.zeros_like(torch.Tensor(label)).to(torch.int64)
            targets.append(utils.to_cuda(target))

        print(targets)


        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(
            loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict
        )
        print("Losses", losses)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {
            f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
        }
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items()
            if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(
            loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled
        )
        metric_logger.update(class_error=loss_dict_reduced["class_error"])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        # loss = model.module.compute_objective(logit_mask, batch["query_mask"])

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        scheduler.step()
        # print(loss)

        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])


     # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


In [4]:
dist.init_process_group(backend="nccl")

local_rank = dist.get_rank()
print("Num cuda", torch.cuda.device_count(), "Local Rank", local_rank)

# local_rank = args.local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)

if utils.is_main_process():
    Logger.initialize(args, training=True)
utils.fix_randseed(args.seed)


|             datapath: .                       
|            benchmark: gso_detr                
|              logpath: gso_train_detr          
|                  bsz: 4                       
|                   lr: 0.0001                  
|         weight_decay: 1e-06                   
|               epochs: 1                       
|              nworker: 8                       
|                 seed: 321                     
|                 fold: 0                       
|            condition: mask                    
|           use_ignore: True                    
|           local_rank: 0                       
|            num_query: 50                      
|             backbone: resnet50                
|                nshot: 14                      
|          load_weight:                         
|           sam_weight: /home/icetenny/senior-1/segment-anything/model/sam_vit_h_4b8939.pth
|        clip_max_norm: 0.1                     
|                masks: F

Num cuda 1 Local Rank 0


In [5]:
# Model initialization
model, criterion = build_SEN(args=args, device=device)


if utils.is_main_process():
    Logger.log_params(model)


# Load Weight
if args.load_weight != "":
    if os.path.exists(args.load_weight):
        model.load_state_dict(torch.load(args.load_weight, map_location=device))
        print(f"Model loaded from {args.load_weight}")
    else:
        print(f"No saved model found at {args.load_weight}")


Backbone # param.: 23685367
Learnable # param.: 1725190
Total # param.: 25410557


In [6]:
# Dataset initialization
FSSDataset.initialize(
    img_size=512, datapath=args.datapath, use_original_imgsize=False
)
dataset, dataloader_trn = FSSDataset.build_dataloader(
    args.benchmark, args.bsz, args.nworker, args.fold, "trn", shot=args.nshot
)

# dataloader_val = FSSDataset.build_dataloader(
#     args.benchmark, args.bsz, args.nworker, args.fold, "val"
# )


Total (trn) Images are : 10000
Total (trn) Query are : 87095


In [8]:
i=30
for batch in dataloader_trn:
    print(batch['query_img'].shape, batch['support_masks'].shape, batch['support_imgs'].shape, batch["bboxes"], batch['unique_obj_id'])
    i-=1

    if i==0:
        break

torch.Size([4, 3, 512, 512]) torch.Size([4, 14, 512, 512]) torch.Size([4, 14, 3, 512, 512]) [[[0.330556, 0.715741, 0.213889, 0.246296]], [[0.459028, 0.880556, 0.495833, 0.238889]], [[0.200694, 0.627778, 0.104167, 0.074074]], [[0.494444, 0.503704, 0.280556, 0.374074]]] [[9], [8], [12], [15]]
torch.Size([4, 3, 512, 512]) torch.Size([4, 14, 512, 512]) torch.Size([4, 14, 3, 512, 512]) [[[0.495833, 0.559259, 0.222222, 0.166667]], [[0.293056, 0.655556, 0.455556, 0.474074]], [[0.488889, 0.496296, 0.1, 0.162963]], [[0.258333, 0.281481, 0.433333, 0.562963]]] [[9], [3], [7], [1]]
torch.Size([4, 3, 512, 512]) torch.Size([4, 14, 512, 512]) torch.Size([4, 14, 3, 512, 512]) [[[0.416667, 0.167593, 0.188889, 0.335185]], [[0.439583, 0.623148, 0.579167, 0.735185]], [[0.509722, 0.831481, 0.230556, 0.337037]], [[0.209722, 0.530556, 0.313889, 0.538889]]] [[17], [3], [3], [10]]
torch.Size([4, 3, 512, 512]) torch.Size([4, 14, 512, 512]) torch.Size([4, 14, 3, 512, 512]) [[[0.209722, 0.412963, 0.163889, 0.3962

In [8]:
batch['query_img'].shape

torch.Size([4, 3, 512, 512])

In [7]:
optimizer = optim.AdamW(
    [
        {"params": model.module.transformer_decoder.parameters()},
        {"params": model.module.downsample_query.parameters(), "lr": args.lr},
        {"params": model.module.merge_1.parameters(), "lr": args.lr},
    ],
    lr=args.lr,
    weight_decay=args.weight_decay,
    betas=(0.9, 0.999),
)
Evaluator.initialize(args)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=args.epochs * len(dataloader_trn)
)


In [8]:

# Training
best_val_miou = float("-inf")
best_val_loss = float("inf")
for epoch in range(args.epochs):

    trn_loss, trn_miou, trn_fb_iou = train_nshot(
        args=args,
        epoch=epoch,
        model=model,
        dataloader=dataloader_trn,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        nshot=args.nshot,
        max_norm=args.clip_max_norm
    )
    # with torch.no_grad():
    #     val_loss, val_miou, val_fb_iou = train_nshot(
    #         args,
    #         epoch,
    #         model,
    #         sam_model,
    #         dataloader_val,
    #         optimizer,
    #         scheduler,
    #         training=False,
    #     )

    # Save the best model
    if trn_miou > best_val_miou:
        best_val_miou = trn_miou
        if utils.is_main_process():
            Logger.save_model_miou(model, epoch, trn_miou)
    if utils.is_main_process():
        Logger.tbd_writer.add_scalars(
            "data/loss", {"trn_loss": trn_loss, "val_loss": 0}, epoch
        )
        Logger.tbd_writer.add_scalars(
            "data/miou", {"trn_miou": trn_miou, "val_miou": 0}, epoch
        )
        Logger.tbd_writer.add_scalars(
            "data/fb_iou",
            {"trn_fb_iou": trn_fb_iou, "val_fb_iou": 0},
            epoch,
        )
        Logger.tbd_writer.flush()
Logger.tbd_writer.close()
Logger.info("==================== Finished Training ====================")

torch.Size([4, 50, 256])
torch.Size([4, 3, 512, 512])
torch.Size([4, 50, 2]) torch.Size([4, 50, 4])
[{'boxes': tensor([[0.3306, 0.7157, 0.2139, 0.2463]], device='cuda:0'), 'labels': tensor([0], device='cuda:0')}, {'boxes': tensor([[0.4590, 0.8806, 0.4958, 0.2389]], device='cuda:0'), 'labels': tensor([0], device='cuda:0')}, {'boxes': tensor([[0.2007, 0.6278, 0.1042, 0.0741]], device='cuda:0'), 'labels': tensor([0], device='cuda:0')}, {'boxes': tensor([[0.4944, 0.5037, 0.2806, 0.3741]], device='cuda:0'), 'labels': tensor([0], device='cuda:0')}]
OWWO
torch.Size([200, 2]) torch.Size([200, 4])
tensor([0, 0, 0, 0], device='cuda:0') torch.Size([4]) torch.int64
tensor([[0.3306, 0.7157, 0.2139, 0.2463],
        [0.4590, 0.8806, 0.4958, 0.2389],
        [0.2007, 0.6278, 0.1042, 0.0741],
        [0.4944, 0.5037, 0.2806, 0.3741]], device='cuda:0') torch.Size([4, 4]) torch.float32
Indice [(tensor([24]), tensor([0])), (tensor([24]), tensor([0])), (tensor([35]), tensor([0])), (tensor([22]), tensor([0

IndexError: too many indices for tensor of dimension 1

In [13]:
targets = [{'boxes': [[0.330556, 0.715741, 0.213889, 0.246296]], 'labels': [9]}, {'boxes': [[0.459028, 0.880556, 0.495833, 0.238889]], 'labels': [8]}, {'boxes': [[0.200694, 0.627778, 0.104167, 0.074074]], 'labels': [12]}, {'boxes': [[0.494444, 0.503704, 0.280556, 0.374074]], 'labels': [15]}]

for t in targets:
    t['boxes'] = torch.tensor(t['boxes'])
    t['labels'] = torch.tensor(t['labels'])
targets

[{'boxes': tensor([[0.3306, 0.7157, 0.2139, 0.2463]]), 'labels': tensor([9])},
 {'boxes': tensor([[0.4590, 0.8806, 0.4958, 0.2389]]), 'labels': tensor([8])},
 {'boxes': tensor([[0.2007, 0.6278, 0.1042, 0.0741]]), 'labels': tensor([12])},
 {'boxes': tensor([[0.4944, 0.5037, 0.2806, 0.3741]]), 'labels': tensor([15])}]

In [11]:
tgt_bbox = torch.cat([v["boxes"] for v in targets])
tgt_bbox

tensor([[0.3306, 0.7157, 0.2139, 0.2463],
        [0.4590, 0.8806, 0.4958, 0.2389],
        [0.2007, 0.6278, 0.1042, 0.0741],
        [0.4944, 0.5037, 0.2806, 0.3741]])

In [15]:

tgt_ids = torch.cat([v["labels"] for v in targets])
tgt_ids

tensor([ 9,  8, 12, 15])

In [9]:
import torch

out_prob = torch.tensor([
    [0.1, 0.6],  # Probabilities for query 1
    [0.2, 0.2],  # Probabilities for query 2
    [0.7, 0.1],  # Probabilities for query 3
    [0.1, 0.6],  # Probabilities for query 1
    [0.2, 0.2],  # Probabilities for query 2
    [0.7, 0.1],  # Probabilities for query 3
])  # Shape: [3, 3] (3 queries, 3 classes)

tgt_ids = torch.tensor([0, 0,0,0,0]).to(torch.int64)  # Ground truth classes for two targets

out_prob.shape, tgt_ids.shape, tgt_ids.dtype

(torch.Size([6, 2]), torch.Size([5]), torch.int64)

In [10]:
cost_class = -out_prob[:, tgt_ids]
cost_class

tensor([[-0.1000, -0.1000, -0.1000, -0.1000, -0.1000],
        [-0.2000, -0.2000, -0.2000, -0.2000, -0.2000],
        [-0.7000, -0.7000, -0.7000, -0.7000, -0.7000],
        [-0.1000, -0.1000, -0.1000, -0.1000, -0.1000],
        [-0.2000, -0.2000, -0.2000, -0.2000, -0.2000],
        [-0.7000, -0.7000, -0.7000, -0.7000, -0.7000]])