In [1]:
import sys
sys.path.append('../src')
from models import *
from strategies import *
from custom_datasets import *
import numpy as np
np.random.seed(0)
import tqdm

import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import os
import torch
torch.cuda.empty_cache()
import torch.nn as nn

from PIL import Image
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
import pandas as pd

import time
import json
import wandb
from unet_model import *


In [2]:
# Project root; every other location used by the notebook is derived from it.
main_path = "/root/Master_Thesis/"
# Train/test CSV dataframes.
dataframes_path = f"{main_path}data/dataframes/"
# Checkpoint file handed to SAMOracle when SAM is enabled.
sam_path = f"{main_path}sam/sam_vit_h_4b8939.pth"
# Notebook directory; trained model weights are stored underneath it.
notebooks_path = f"{main_path}scripts/notebooks/"
# Output directory for experiment parameter dumps (spelling kept: it is an on-disk path).
expirements_path = f"{main_path}expirements/"

In [3]:
# Dataset selection: keep exactly one group active. Each group sets the dataset
# tag (`df_name`) and reads the matching train/test CSV files.

# df_name = "brain_df"
# train_df = pd.read_csv(dataframes_path+"brain_df_train.csv")
# test_df = pd.read_csv(dataframes_path+"brain_df_test.csv")

# df_name = "lung_tumor_df"
# train_df = pd.read_csv(dataframes_path+"lung_df_train.csv")
# test_df = pd.read_csv(dataframes_path+"lung_df_test.csv")

df_name = "lunar_df"
train_df = pd.read_csv(f"{dataframes_path}lunar_df_train.csv")
test_df = pd.read_csv(f"{dataframes_path}lunar_df_test.csv")

In [4]:
# Number of rows in the test dataframe.
len(test_df)

2930

In [5]:
# with open('params.json') as f:
#     params = json.load(f)
#     print(params)

In [6]:
# Number of rows in the training dataframe (the pool the initial labelled set is drawn from).
len(train_df)

6836

In [7]:
# Experiment configuration. The same dict is later passed to wandb as the run
# config and to Net / the query strategy.
params = {
    "n_epoch": 35,
    "train_args": {"batch_size": 4, "num_workers": 1},
    "test_args": {"batch_size": 256, "num_workers": 1},
    "optimizer_args": {"lr": 5e-3, "momentum": 0.9},
    "use_sam": False,
    "use_predictor": False,
    "use_generator": False,
    "init_set_size": 300,
    "rounds": 30,
    "activate_sam_at_round": 1,
    "img_size": (128, 128),
    "voting": False,
    "pre_trained": True,
    "dataset": "Lunar_Rocky_Landscape",
}

# Training regime, decided in priority order:
#   whole training set labelled up front -> "no_active"
#   SAM disabled                         -> "no_sam"
#   SAM enabled with voting              -> "voters"
#   SAM enabled without voting           -> "withSAM_NoVoting"
if params["init_set_size"] == len(train_df):
    _regime = "no_active"
elif not params["use_sam"]:
    _regime = "no_sam"
elif params["voting"]:
    _regime = "voters"
else:
    _regime = "withSAM_NoVoting"
params["training_type"] = _regime

# Where weights live: a single .pt file for "no_active", otherwise a
# per-regime / per-initial-set-size location (with an extra "voters_<size>"
# component for the voting regime).
_models_root = f'{notebooks_path}trained_models/{params["dataset"]}'
_side = params["img_size"][0]
if _regime == "no_active":
    _prefix = "pre_trained" if params["pre_trained"] else "not_pre_trained"
    params["model_path"] = f"{_models_root}/no_active/{_prefix}_Unet_{_side}.pt"
else:
    params["model_path"] = f'{_models_root}/{_regime}/{params["init_set_size"]}'
    if _regime == "voters":
        params["model_path"] = f'{params["model_path"]}/voters_{_side}'

params["test_set_size"] = len(test_df)
params["df"] = df_name
# Query 5% of the initial set size per round, but never fewer than one sample.
params["query_num"] = int(0.05 * params["init_set_size"]) or 1
params["strategy"] = "MarginSampling"

In [8]:
# Open a wandb run: one project per dataset, the notes identify the training
# regime and initial set size, and the full params dict is tracked as config.
wandb.init(
    project=params["dataset"],
    notes="{}_{}".format(params["training_type"], params["init_set_size"]),
    config=params,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33msaleemfares1995-sf[0m ([33mthesis_fares[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
# Only build the SAM oracle when the configuration asks for it; downstream code
# receives None otherwise.
sam = (
    SAMOracle(checkpoint_path=sam_path, img_size=params["img_size"])
    if params['use_sam']
    else None
)

In [10]:
# Build the segmentation network and pin it to a shared initial state.
#
# The first time a given architecture is used, its freshly constructed weights
# are written to disk; every later run loads that file instead, so all
# experiments with the same architecture start from identical weights.
if params["pre_trained"]:
    # smp U-Net with a resnet34 encoder.
    # NOTE(review): presumably relies on smp's default encoder weights for the
    # "pre-trained" part - confirm.
    model = smp.create_model('Unet', encoder_name='resnet34', in_channels=3, classes=1)
    shared_init_file = notebooks_path + "trained_models/shared_init_state_pre_trained.pt"
else:
    # Project-local U-Net implementation (from unet_model).
    model = UNet(n_channels=3, n_classes=1, bilinear=True)
    shared_init_file = notebooks_path + "trained_models/shared_init_state_not_trained.pt"

if not os.path.isfile(shared_init_file):
    # Create the target directory first so the very first run on a clean
    # checkout does not fail inside torch.save.
    os.makedirs(os.path.dirname(shared_init_file), exist_ok=True)
    torch.save(model.state_dict(), shared_init_file)

# Kept as a module-level name: the training loop reloads it to reset the
# network before each retraining round.
init_state_Unet = torch.load(shared_init_file)
model.load_state_dict(init_state_Unet)

In [11]:
# Per-experiment initial checkpoint (index 0).
#
# For the active-learning regimes the initial weights are stored next to the
# per-round checkpoints. If the file is missing it is created from the current
# model; either way the model (and `init_state_Unet`) end up holding exactly
# what is on disk. "no_active" has no such file, so `init_path` stays empty and
# the shared initial state loaded earlier is kept.
init_path = ""

if params["training_type"] == "voters":
    init_path = params["model_path"] + '_0/main_Unet.pt'

elif params["training_type"] == "withSAM_NoVoting":
    # Use the configured image size instead of a literal 128 so the name stays
    # consistent with the per-round checkpoints written by the training loop.
    init_path = f'{params["model_path"]}/main_Unet_{params["img_size"][0]}_0.pt'

elif params["training_type"] == "no_sam":
    init_path = f'{params["model_path"]}/Active_{params["init_set_size"]}_{0}_no_sam_{params["img_size"][0]}_Unet.pt'

if init_path:
    if not os.path.isfile(init_path):
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(os.path.dirname(init_path), exist_ok=True)
        torch.save(model.state_dict(), init_path)
    init_state_Unet = torch.load(init_path)
    model.load_state_dict(init_state_Unet)

# Wrap the raw model in the project's training/inference helper.
# NOTE(review): the device is hard-coded to CUDA, so this cell needs a GPU.
net = Net(model, params, device=torch.device("cuda"))

In [12]:
# Show which initial-checkpoint file was selected (empty string when none applies).
init_path

'/root/Master_Thesis/scripts/notebooks/trained_models/Lunar_Rocky_Landscape/no_sam/300/Active_300_0_no_sam_128_Unet.pt'

In [13]:
def get_data(handler, train_df, test_df):
    """Build the project's `Data` container from the train/test dataframes.

    The "images" and "masks" columns of each dataframe are passed as plain
    lists, followed by `handler`. Image size and the SAM flag come from the
    module-level `params`, and the processed-data directory is derived from the
    module-level `main_path`; both globals are read at call time, so
    `main_path` must still hold the project root when this is called.
    """
    train_images = train_df["images"].to_list()
    train_masks = train_df["masks"].to_list()
    test_images = test_df["images"].to_list()
    test_masks = test_df["masks"].to_list()
    return Data(
        train_images,
        train_masks,
        test_images,
        test_masks,
        handler,
        img_size=params["img_size"],
        df=train_df,
        path=main_path + "/data/processed/",
        use_sam=params['use_sam'],
    )

In [14]:
# Build the dataset wrapper, then set up the initial labelled set with the
# configured size.
# NOTE(review): presumably initialize_labels marks that many training samples
# as labelled - confirm in custom_datasets.
data = get_data(handler=Handler, train_df=train_df, test_df=test_df)
data.initialize_labels(params["init_set_size"])

In [15]:
# Record the query strategy in the config and instantiate it on the dataset,
# the wrapped network and the (optional) SAM oracle.
params["strategy"] = "MarginSampling"
strategy = MarginSampling(dataset=data, net=net, sam=sam, params=params)

In [16]:
# Show the base location under which this experiment's checkpoints are stored.
params["model_path"]

'/root/Master_Thesis/scripts/notebooks/trained_models/Lunar_Rocky_Landscape/no_sam/300'

In [17]:
# Active-learning driver.
#
# "Round 0" trains (or reloads) the model on the initial labelled set. Every
# later round queries new samples, updates the labels, retrains from the
# initial weights (or reloads an existing checkpoint), and evaluates on the
# test set. Metrics go to wandb and to `logs` after every round.
#
# Checkpoint paths are kept in `ckpt_path`, NOT in `main_path`: `main_path` is
# the project root that `get_data` reads at call time and must not be clobbered.


def _checkpoint_path(checkpoint_idx):
    """Return the weights file for checkpoint number `checkpoint_idx`.

    Index 0 is the untrained initial state, so the model trained in round ``r``
    is stored under index ``r + 1``. Training types without per-round files
    (i.e. "no_active") always map to params["model_path"] itself.
    """
    training_type = params["training_type"]
    base = params["model_path"]
    side = params["img_size"][0]
    if training_type == "no_sam":
        return f'{base}/Active_{params["init_set_size"]}_{checkpoint_idx}_no_sam_{side}_Unet.pt'
    if training_type == "voters":
        return f'{base}_{checkpoint_idx}/main_Unet.pt'
    if training_type == "withSAM_NoVoting":
        return f'{base}/main_Unet_{side}_{checkpoint_idx}.pt'
    return f'{base}'


def _save_checkpoint(path):
    """Write the current network weights to `path`, creating its directory if needed."""
    ckpt_dir = os.path.dirname(path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(strategy.net.net.state_dict(), path)


def _log_test_metrics(round_idx, logits, mask_gt):
    """Compute test metrics, send them to wandb, append a summary line to `logs`.

    Returns the (iou_score, accuracy, precision, recall, f1_score) tuple.
    """
    iou_score, accuracy, precision, recall, f1_score = data.cal_test_metrics(logits, mask_gt)
    wandb.log({"iou_score" : iou_score, "accuracy" : accuracy, "precision" : precision, "recall" : recall, "f1_score" : f1_score})
    logs.append(f"Round {round_idx} testing metrics: iou_score = {iou_score:.2f}, accuracy = {accuracy:.2f}, precision = {precision:.2f}, recall = {recall:.2f}, f1_score = {f1_score:.2f}")
    print(logs[-1])
    return iou_score, accuracy, precision, recall, f1_score


print("Round 0")
logs = []
# Only these regimes keep one checkpoint per round; anything else uses the
# single file at params["model_path"] and is always (re)trained in round 0.
has_round_checkpoints = params["training_type"] in ("no_sam", "voters", "withSAM_NoVoting")
ckpt_path = _checkpoint_path(1)

if has_round_checkpoints and os.path.isfile(ckpt_path):
    # A previous run already trained round 0: reuse its weights.
    print(ckpt_path)
    strategy.net.net.load_state_dict(torch.load(ckpt_path))
else:
    strategy.train()
    _save_checkpoint(ckpt_path)
    if has_round_checkpoints:
        print("Saved : " + ckpt_path)

logits, mask_gt = strategy.predict(data.get_test_data())
iou_score, accuracy, precision, recall, f1_score = _log_test_metrics(0, logits, mask_gt)

for rd in range(1, params["rounds"]):
    print(f"Round {rd}")

    # query
    print("Querying")
    query_idxs = strategy.query(params["query_num"])
    print(query_idxs)

    # update labels
    if params["use_sam"] and rd >= params["activate_sam_at_round"]:
        print("Updating with sam")
        sam_kwargs = dict(start_sam=True, use_predictor=params["use_predictor"], use_generator=params["use_generator"], round=rd)
        if params["training_type"] == "voters":
            masks = strategy.update_voting(query_idxs, **sam_kwargs)
        elif params["training_type"] == "withSAM_NoVoting":
            masks = strategy.update(query_idxs, **sam_kwargs)
    else:
        print("Updating without sam")
        strategy.update(query_idxs)

    print("Reset and train")
    ckpt_path = _checkpoint_path(rd + 1)
    if not os.path.isfile(ckpt_path):
        # Retrain from the shared initial weights rather than fine-tuning the
        # previous round's model.
        strategy.net.net.load_state_dict(init_state_Unet)
        strategy.train()
        _save_checkpoint(ckpt_path)
        print("Saved : " + ckpt_path)
    else:
        strategy.net.net.load_state_dict(torch.load(ckpt_path))

    # calculate accuracy
    # Score against the ground truth returned by THIS prediction call; the
    # original assigned it to a misspelled name and silently reused round 0's.
    logits, mask_gt = strategy.predict(data.get_test_data())
    iou_score, accuracy, precision, recall, f1_score = _log_test_metrics(rd, logits, mask_gt)

params['logs'] = logs
wandb.finish()

Round 0


100%|███████████████████████████████████████████████████| 35/35 [04:59<00:00,  8.55s/it, loss=0.463]


Saved : /root/Master_Thesis/scripts/notebooks/trained_models/Lunar_Rocky_Landscape/no_sam/300/Active_300_1_no_sam_128_Unet.pt
Round 0 testing metrics: iou_score = 0.40, accuracy = 0.96, precision = 0.68, recall = 0.49, f1_score = 0.57
Round 1
Querying
[4358 4374 4373 4371 4370 4368 4366 4365 4363 4362 4360 4359 4375 4357
 4356]
Updating without sam
Reset and train


100%|███████████████████████████████████████████████████| 35/35 [05:18<00:00,  9.09s/it, loss=0.141]


Saved : /root/Master_Thesis/scripts/notebooks/trained_models/Lunar_Rocky_Landscape/no_sam/300/Active_300_2_no_sam_128_Unet.pt
Round 1 testing metrics: iou_score = 0.37, accuracy = 0.97, precision = 0.77, recall = 0.42, f1_score = 0.54
Round 2
Querying
[3940 3951 3950 3949 3948 6274 3946 3945 3944 3943 3941 6272 3939 6277
 3936]
Updating without sam
Reset and train


100%|███████████████████████████████████████████████████| 35/35 [05:35<00:00,  9.58s/it, loss=0.233]


Saved : /root/Master_Thesis/scripts/notebooks/trained_models/Lunar_Rocky_Landscape/no_sam/300/Active_300_3_no_sam_128_Unet.pt
Round 2 testing metrics: iou_score = 0.41, accuracy = 0.97, precision = 0.73, recall = 0.48, f1_score = 0.58
Round 3
Querying
[4401 4414 4413 4411 4410 4409 4408 4407 4406 4404 4403 4402 4415 4400
 4398]
Updating without sam
Reset and train


100%|███████████████████████████████████████████████████| 35/35 [05:50<00:00, 10.02s/it, loss=0.224]


Saved : /root/Master_Thesis/scripts/notebooks/trained_models/Lunar_Rocky_Landscape/no_sam/300/Active_300_4_no_sam_128_Unet.pt
Round 3 testing metrics: iou_score = 0.40, accuracy = 0.97, precision = 0.74, recall = 0.46, f1_score = 0.57
Round 4
Querying
[4376 4389 4388 4387 4386 4385 4382 4381 4380 4379 4378 4377 4390 4367
 4361]
Updating without sam
Reset and train


100%|███████████████████████████████████████████████████| 35/35 [06:04<00:00, 10.41s/it, loss=0.142]


Saved : /root/Master_Thesis/scripts/notebooks/trained_models/Lunar_Rocky_Landscape/no_sam/300/Active_300_5_no_sam_128_Unet.pt
Round 4 testing metrics: iou_score = 0.39, accuracy = 0.97, precision = 0.78, recall = 0.43, f1_score = 0.56
Round 5
Querying
[4393 4427 4426 4425 4424 4423 4422 4417 4416 4396 4395 4394 4428 4392
 4391]
Updating without sam
Reset and train


100%|███████████████████████████████████████████████████| 35/35 [06:21<00:00, 10.89s/it, loss=0.219]


Saved : /root/Master_Thesis/scripts/notebooks/trained_models/Lunar_Rocky_Landscape/no_sam/300/Active_300_6_no_sam_128_Unet.pt
Round 5 testing metrics: iou_score = 0.42, accuracy = 0.96, precision = 0.64, recall = 0.55, f1_score = 0.59
Round 6
Querying
[4482 4495 4494 4493 4491 4490 4489 4487 4486 4485 4484 4483 4496 4481
 4479]
Updating without sam
Reset and train


100%|███████████████████████████████████████████████████| 35/35 [06:36<00:00, 11.34s/it, loss=0.168]


Saved : /root/Master_Thesis/scripts/notebooks/trained_models/Lunar_Rocky_Landscape/no_sam/300/Active_300_7_no_sam_128_Unet.pt
Round 6 testing metrics: iou_score = 0.41, accuracy = 0.97, precision = 0.75, recall = 0.47, f1_score = 0.58
Round 7
Querying
[4328 4346 4343 4341 4340 4339 4338 4337 4334 4333 4331 4330 4347 4327
 4326]
Updating without sam
Reset and train


 57%|██████████████████████████████▎                      | 20/35 [04:11<03:02, 12.15s/it, loss=0.3]

In [None]:
# from pathlib import Path
# Path("./my/directory").mkdir(parents=True, exist_ok=True)

import os



In [None]:
# Show the kernel's current working directory (relative paths resolve against it).
os.getcwd()

'/root/Master_Thesis/scripts/notebooks'

In [None]:
# for dirname, _, filenames in os.walk(expirements_path):
#     filename = "expirement_{}.json".format(len(filenames))
#     file_path = os.path.join(dirname, filename)
#     with open(file_path, 'w') as f:
#         json.dump(params, f)
#         print(filename)

# UNet

In [None]:
# strategy.net = net_1

In [None]:
# model_1 = torch.load('model_1.pt')
# strategy.net.net.load_state_dict(model_1)
# strategy.net.clf = strategy.net.net.to(torch.device("cuda:1"))
# mask_1 = strategy.predict(strategy.dataset.handler([strategy.dataset.df["images"][idx]], [strategy.dataset.df["masks"][idx]]))[0]
# mask_1 = (mask_1.squeeze().cpu().sigmoid()> 0.5).float()

In [None]:
# model_2 = torch.load('model_2.pt')
# strategy.net.net.load_state_dict(model_2)
# strategy.net.clf = strategy.net.net.to(torch.device("cuda:1"))
# mask_2 = strategy.predict(strategy.dataset.handler([strategy.dataset.df["images"][idx]], [strategy.dataset.df["masks"][idx]]))[0]
# mask_2 = (mask_2.squeeze().cpu().sigmoid()> 0.5).float()

In [None]:
# model_3 = torch.load('model_3.pt')
# strategy.net.net.load_state_dict(model_3)
# strategy.net.clf = strategy.net.net.to(torch.device("cuda:1"))
# mask_3 = strategy.predict(strategy.dataset.handler([strategy.dataset.df["images"][idx]], [strategy.dataset.df["masks"][idx]]))[0]
# mask_3 = (mask_3.squeeze().cpu().sigmoid()> 0.5).float()

# FPN

In [None]:
# strategy.net = net_2

In [None]:
# model_4 = torch.load('model_4.pt')
# strategy.net.net.load_state_dict(model_4)
# strategy.net.clf = strategy.net.net.to(torch.device("cuda:1"))
# mask_4 = strategy.predict(strategy.dataset.handler([strategy.dataset.df["images"][idx]], [strategy.dataset.df["masks"][idx]]))[0]
# mask_4 = (mask_4.squeeze().cpu().sigmoid()> 0.5).float()

In [None]:
# def predict(model_states:list, idx):
#     masks = []
#     for state in model_states:
#         model = torch.load(state)
#         strategy.net.net.load_state_dict(model)
#         strategy.net.clf = strategy.net.net.to(torch.device("cuda:1"))
#         mask = strategy.predict(strategy.dataset.handler([strategy.dataset.df["images"][idx]], [strategy.dataset.df["masks"][idx]]))[0]
#         mask = (mask.squeeze().cpu().sigmoid()> 0.5).float()
#         masks.append(mask)
#     return masks

In [None]:
# masks = predict(["model_1.pt", "model_2.pt", "model_3.pt", "model_4.pt"], 2197)

In [None]:
# import supervision as sv
# imgs = [mask_1, mask_2, mask_3, mask_4]
# sv.plot_images_grid(
#     images=imgs,
#     grid_size=(1, len(imgs)),
#     # titles=['mask_1', 'mask_2', "mask_3","mask_4"]
# )