### Clone (or update) the dependency repositories and install their requirements

In [1]:
import os
import sys
import subprocess

# Third-party repositories to vendor into Scripts/libs: local folder name -> git URL.
repos = {
    "lightning-sam": "https://github.com/luca-medeiros/lightning-sam.git"
}

# Repositories are cloned into <parent of the notebook's cwd>/Scripts/libs.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
scripts_dir = os.path.join(parent_dir, 'Scripts', 'libs')

# Make <parent>/Scripts itself importable (its parent directory of libs).
sys.path.append(os.path.dirname(scripts_dir))
print(f"Scripts directory: {scripts_dir}")

os.makedirs(scripts_dir, exist_ok=True)

for repo_name, repo_url in repos.items():
    repo_path = os.path.join(scripts_dir, repo_name)

    # Update an existing checkout, otherwise clone it fresh; fail loudly on git errors.
    if os.path.isdir(repo_path):
        print(f"Pulling the latest version of {repo_name}...")
        subprocess.run(['git', 'pull'], cwd=repo_path, check=True)
    else:
        print(f"Cloning the repository {repo_name}...")
        subprocess.run(['git', 'clone', repo_url, repo_path], check=True)

    # Put every source sub-directory on sys.path so later cells can use flat
    # imports such as `from model import Model`.
    for root, subdirs, _files in os.walk(repo_path):
        # Prune in place so os.walk never descends into hidden trees (.git, .github, ...)
        # or bytecode caches: they hold no importable code and would flood sys.path.
        subdirs[:] = [d for d in subdirs if not d.startswith('.') and d != '__pycache__']
        for subdir in subdirs:
            full_path = os.path.join(root, subdir)
            if full_path not in sys.path:
                sys.path.append(full_path)

    # Install the repository's requirements into the *kernel's* interpreter.
    # A bare `pip` on PATH may belong to a different Python installation.
    requirements_path = os.path.join(repo_path, 'requirements.txt')
    if os.path.isfile(requirements_path):
        print(f"Installing dependencies for {repo_name}...")
        subprocess.run(
            [sys.executable, '-m', 'pip', 'install', '-r', requirements_path],
            check=True,
        )

Scripts directory: i:\My Drive\Flow_segmentation\Scripts\libs
Pulling the latest version of lightning-sam...


### Inspect the fine-tuning configuration

In [4]:
from Fine_tune_config import ft_cfg
import pandas as pd

# Show long values (e.g. output paths) without truncation.
pd.set_option('display.max_colwidth', None)

# Flatten nested sections into underscore-joined keys (e.g. opt_learning_rate),
# then pivot so each key becomes a row of a single "Value" column.
flat_config = pd.json_normalize(ft_cfg, sep='_')
config_df = flat_config.T.set_axis(['Value'], axis=1)
print(config_df)

                                                                                               Value
num_devices                                                                                        0
batch_size                                                                                         1
num_workers                                                                                        2
num_epochs                                                                                         1
eval_interval                                                                                      2
out_dir                                         i:\My Drive\Flow_segmentation\Checkpoints_models\out
opt_learning_rate                                                                             0.0008
opt_weight_decay                                                                              0.0001
opt_decay_factor                                                                           

### Train the model

In [None]:
import os
import lightning as L
from model import Model
from dataset import load_datasets
import torch.nn.functional as F
from train import train_sam, validate, configure_opt
from lightning.fabric.loggers import TensorBoardLogger


# Fabric rejects devices=0, so num_devices == 0 in the config is interpreted as
# "train on the CPU with a single process" rather than passed through verbatim.
use_cpu = ft_cfg.num_devices == 0
accelerator = "cpu" if use_cpu else "auto"
devices = 1 if use_cpu else ft_cfg.num_devices

fabric = L.Fabric(
    accelerator=accelerator,
    devices=devices,
    strategy="auto",
    loggers=[TensorBoardLogger(ft_cfg.out_dir, name="lightning-sam")]
)

fabric.launch()
# Seed offset by rank so each process gets its own random stream.
fabric.seed_everything(1337 + fabric.global_rank)

# Only rank 0 creates the output directory.
if fabric.global_rank == 0:
    os.makedirs(ft_cfg.out_dir, exist_ok=True)

# Build the model with fabric.device as the default device.
with fabric.device:
    model = Model(ft_cfg)
    model.setup()

train_data, val_data = load_datasets(ft_cfg, model.model.image_encoder.img_size)
# Public Fabric API (instead of the private _setup_dataloader) for sampler/device wrapping.
train_data, val_data = fabric.setup_dataloaders(train_data, val_data)

optimizer, scheduler = configure_opt(ft_cfg, model)
model, optimizer = fabric.setup(model, optimizer)

train_sam(ft_cfg, fabric, model, optimizer, scheduler, train_data, val_data)
# One more validation pass on the final weights.
validate(fabric, model, val_data, epoch=0)

### Use the fine-tuned model: save predicted masks for every validation image

In [None]:
import os
import cv2
import torch
from box import Box
from dataset import COCODataset
from model import Model
from tqdm import tqdm
from utils import draw_image

def visualize(cfg: Box):
    """Overlay box-prompted mask predictions on every validation image and save them.

    For each image of the validation COCO dataset, its annotated bounding boxes
    are used as box prompts for the model's predictor. The predicted masks are
    drawn over the image, and the result is written to ``cfg.out_dir`` under the
    image's original ``file_name``.

    Args:
        cfg: Configuration passed to ``Model``; must also provide
            ``dataset.val.root_dir``, ``dataset.val.annotation_file`` and ``out_dir``.

    Raises:
        FileNotFoundError: If an image referenced by the annotations cannot be read.
        OSError: If an output image cannot be written.
    """
    # Fall back to the CPU when CUDA is unavailable instead of crashing in .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Model(cfg)
    model.setup()
    model.eval()
    model.to(device)

    dataset = COCODataset(root_dir=cfg.dataset.val.root_dir,
                          annotation_file=cfg.dataset.val.annotation_file,
                          transform=None)

    predictor = model.get_predictor()
    os.makedirs(cfg.out_dir, exist_ok=True)

    for image_id in tqdm(dataset.image_ids):
        image_info = dataset.coco.loadImgs(image_id)[0]
        image_path = os.path.join(dataset.root_dir, image_info['file_name'])
        image_output_path = os.path.join(cfg.out_dir, image_info['file_name'])

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV loads BGR; the predictor is fed RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Convert COCO [x, y, width, height] boxes to [x0, y0, x1, y1] corners.
        anns = dataset.coco.loadAnns(dataset.coco.getAnnIds(imgIds=image_id))
        bboxes = [[x, y, x + w, y + h] for x, y, w, h in (ann['bbox'] for ann in anns)]
        bboxes = torch.as_tensor(bboxes, device=model.model.device)
        transformed_boxes = predictor.transform.apply_boxes_torch(bboxes, image.shape[:2])

        predictor.set_image(image)
        masks, _, _ = predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )

        # Masks may live on the GPU while the image is a CPU (NumPy) array; move them first.
        image_output = draw_image(image, masks.squeeze(1).cpu(), boxes=None, labels=None)
        # draw_image received the RGB frame; cv2.imwrite expects BGR, so convert back
        # (assumes draw_image preserves channel order — TODO confirm in utils.draw_image).
        image_output = cv2.cvtColor(image_output, cv2.COLOR_RGB2BGR)

        # file_name may contain sub-directories that do not yet exist under out_dir.
        os.makedirs(os.path.dirname(image_output_path) or ".", exist_ok=True)
        # cv2.imwrite signals failure by returning False rather than raising.
        if not cv2.imwrite(image_output_path, image_output):
            raise OSError(f"Could not write image: {image_output_path}")


# Render mask predictions for the validation set using the same config as training.
visualize(cfg=ft_cfg)