In [1]:
import sys
sys.path.append('../src')
from models import *
from strategies import *
from custom_datasets import *
import numpy as np
np.random.seed(0)
import tqdm

import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import os
import torch
torch.cuda.empty_cache()
import torch.nn as nn

from PIL import Image
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
import pandas as pd

import time
import json
import wandb



In [2]:
# Project paths. The repository root defaults to /root/Master_Thesis/ but can be
# overridden with the MASTER_THESIS_ROOT environment variable, so the notebook can
# run on another machine without editing this cell. os.path.join(..., "") makes
# sure the root always ends with a separator before the suffixes below are appended.
# NOTE(review): "expirements" is kept as spelled — presumably it matches an
# existing directory name; confirm before renaming.
main_path = os.path.join(os.environ.get("MASTER_THESIS_ROOT", "/root/Master_Thesis/"), "")
dataframes_path = main_path + "data/dataframes/"
sam_path = main_path + "sam/sam_vit_h_4b8939.pth"
voters_path = main_path + "scripts/notebooks/trained_models/voters/"
expirements_path = main_path + "expirements/"

In [3]:
# Load the train/test split of the dataset named by `df_name`. The CSV file names
# are derived from it, so switching datasets only requires changing df_name.
df_name = "brain_df"
train_df = pd.read_csv(dataframes_path + f"{df_name}_train.csv")
test_df = pd.read_csv(dataframes_path + f"{df_name}_test.csv")

# get_data() below reads these two columns from both splits.
assert {"images", "masks"} <= set(train_df.columns), "train split is missing images/masks columns"
assert {"images", "masks"} <= set(test_df.columns), "test split is missing images/masks columns"

In [4]:
# Show the configuration stored in params.json for reference.
# NOTE: the next cell rebuilds `params` from scratch, so these loaded values are
# not the ones this run uses (e.g. the file's use_sam differs from the next cell's).
with open('params.json') as params_file:
    params = json.load(params_file)
print(params)

{'n_epoch': 35, 'train_args': {'batch_size': 4, 'num_workers': 1}, 'test_args': {'batch_size': 256, 'num_workers': 1}, 'optimizer_args': {'lr': 0.005, 'momentum': 0.9}, 'use_sam': False, 'use_predictor': False, 'use_generator': False, 'init_set_size': 400, 'rounds': 23, 'activate_sam_at_round': 1, 'img_size': [128, 128]}


In [5]:
# Experiment configuration: training/eval loader settings, optimizer settings,
# SAM switches and the active-learning schedule. This is the dict passed to
# wandb and to the Net/strategy objects below.
params = dict(
    n_epoch=35,
    train_args={'batch_size': 4, 'num_workers': 1},
    test_args={'batch_size': 256, 'num_workers': 1},
    optimizer_args={'lr': 5e-3, 'momentum': 0.9},
    use_sam=True,
    use_predictor=True,
    use_generator=False,
    init_set_size=400,
    rounds=23,
    activate_sam_at_round=1,
    img_size=(128, 128),
)

# Derived entries; all right-hand sides are evaluated before the dict is updated.
params.update(
    test_set_size=len(test_df),
    df=df_name,
    # Each round queries 5% of the initial labelled-set size.
    query_num=int(0.05 * params['init_set_size']),
    strategy="MarginSampling",
    voters=f'trained_models/voters/voters_{params["img_size"][0]}_',
)
init_main_state_path = f'trained_models/voters/voters_{params["img_size"][0]}_0'

In [6]:
# wandb.init(
#     # set the wandb project where this run will be logged
#     project="baseline",
#     resume=True
# )

In [7]:
# Open a wandb run in the "baseline" project with the experiment config attached,
# so the metrics logged by the training loop can be traced to their hyperparameters.
wandb.init(project="baseline", config=params)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33msaleemfares1995-sf[0m ([33mthesis_fares[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
# Display the relative path of the initial-weights voters folder for this image size.
init_main_state_path

'trained_models/voters/voters_128_0'

In [9]:
# SAM-based oracle handed to the strategy below; None when SAM is disabled in params.
sam = (
    SAMOracle(checkpoint_path=sam_path, img_size=params["img_size"])
    if params['use_sam']
    else None
)

In [10]:
# from unet_model import *

# model = UNet(n_channels=3, n_classes=1, bilinear=True)

# init_path = 'trained_models/not_pre_trained_Unet_0.pt'
# if not os.path.isfile(init_path):
#     torch.save(model.state_dict(), init_path)
    
# init_state_Unet = torch.load(init_path)

In [11]:
# Main segmentation model: smp U-Net with a ResNet-34 encoder, 3 input channels
# and 1 output channel.
model = smp.create_model(
    'Unet', encoder_name='resnet34', in_channels=3, classes=1
)

# The initial (untrained) weights are cached in voters_{size}_0/main_Unet.pt.
# The training loop resets to them before each retraining, so they are also loaded
# into `model` here. That way round 0 starts from the same initialisation even
# when the cache was written by an earlier session.
init_dir = os.path.join(voters_path, f'voters_{params["img_size"][0]}_0')
init_path = os.path.join(init_dir, 'main_Unet.pt')
if not os.path.isfile(init_path):
    os.makedirs(init_dir, exist_ok=True)
    torch.save(model.state_dict(), init_path)
init_state_Unet = torch.load(init_path)
model.load_state_dict(init_state_Unet)

net = Net(model, params, device=torch.device("cuda"))

In [12]:
def get_data(handler, train_df, test_df, img_size=None, use_sam=None, data_path=None):
    """Build the active-learning ``Data`` container for a train/test split.

    Args:
        handler: handler class forwarded to ``Data`` (the notebook passes ``Handler``).
        train_df: DataFrame with ``images`` and ``masks`` columns; also passed as ``df``.
        test_df: DataFrame with ``images`` and ``masks`` columns.
        img_size: optional override; defaults to ``params["img_size"]``.
        use_sam: optional override; defaults to ``params["use_sam"]``.
        data_path: optional override for the processed-data folder; defaults to
            ``<main_path>/data/processed/``.

    Returns:
        The ``Data`` instance built from the values of the images/masks columns.
    """
    if img_size is None:
        img_size = params["img_size"]
    if use_sam is None:
        use_sam = params["use_sam"]
    if data_path is None:
        # os.path.join avoids the doubled slash produced by main_path + "/data/...".
        data_path = os.path.join(main_path, "data", "processed", "")
    return Data(
        train_df["images"].to_list(),
        train_df["masks"].to_list(),
        test_df["images"].to_list(),
        test_df["masks"].to_list(),
        handler,
        img_size=img_size,
        df=train_df,
        path=data_path,
        use_sam=use_sam,
    )

In [13]:
# Assemble the dataset from both splits, then initialise the labelled pool
# with params["init_set_size"] samples.
data = get_data(handler=Handler, train_df=train_df, test_df=test_df)
data.initialize_labels(params["init_set_size"])

In [14]:
# Margin sampling drives the query step; it wraps the dataset, the U-Net wrapper
# and the (optional) SAM oracle. The strategy name is recorded in params for the logs.
params["strategy"] = "MarginSampling"
strategy = MarginSampling(dataset=data, net=net, sam=sam, params=params)

In [15]:
# Active-learning loop.
# Round 0 trains the main U-Net on the initial labelled pool, or reloads its cached
# checkpoint. Every later round:
#   - asks the strategy for `query_num` indices;
#   - updates the labels (through SAM voting once `activate_sam_at_round` is reached);
#   - retrains from the fixed initial weights;
#   - evaluates on the test split.
#
# Checkpoint layout under voters_path: voters_{size}_0 holds the untrained initial
# weights, and voters_{size}_{k} holds the main model trained in round k-1, i.e. the
# model that performs the querying in round k.


def _main_ckpt_path(idx):
    """Return the main-model checkpoint path voters_{size}_{idx}/main_Unet.pt."""
    return os.path.join(voters_path, f'voters_{params["img_size"][0]}_{idx}', 'main_Unet.pt')


def _save_main(idx):
    """Save the strategy's current main model into checkpoint slot `idx`, creating the folder if needed."""
    ckpt = _main_ckpt_path(idx)
    os.makedirs(os.path.dirname(ckpt), exist_ok=True)
    torch.save(strategy.net.net.state_dict(), ckpt)


def _evaluate_and_log(rd):
    """Score the current main model on the test split, log the metrics to wandb and return a summary line."""
    logits, mask_gt = strategy.predict(data.get_test_data())
    iou_score, accuracy, precision, recall, f1_score = data.cal_test_metrics(logits, mask_gt)
    wandb.log({"iou_score": iou_score, "accuracy": accuracy, "precision": precision, "recall": recall, "f1_score": f1_score})
    return (f"Round {rd} testing metrics: iou_score = {iou_score:.2f}, accuracy = {accuracy:.2f}, "
            f"precision = {precision:.2f}, recall = {recall:.2f}, f1_score = {f1_score:.2f}")


logs = []
print("Round 0")
main_1_path = _main_ckpt_path(1)
if os.path.isfile(main_1_path):
    strategy.net.net.load_state_dict(torch.load(main_1_path))
else:
    # Train from the same initial weights that every later round resets to.
    strategy.net.net.load_state_dict(init_state_Unet)
    strategy.train()
    _save_main(1)
logs.append(_evaluate_and_log(0))
print(logs[-1])

for rd in range(1, params["rounds"]):
    print(f"Round {rd}")

    # Ask the strategy for the next `query_num` sample indices.
    print("Querying")
    query_idxs = strategy.query(params["query_num"])
    print(query_idxs)

    # Update the labels of the queried samples.
    # Alternatives tried: strategy.update_weighted_voting(...) and
    # strategy.update(query_idxs, start_sam=True, ...) with the same keyword arguments.
    if params["use_sam"] and rd >= params["activate_sam_at_round"]:
        print("Updating with sam")
        masks = strategy.update_voting(
            query_idxs,
            start_sam=True,
            use_predictor=params["use_predictor"],
            use_generator=params["use_generator"],
            round=rd,
        )
    else:
        print("Updating without sam")
        strategy.update(query_idxs)

    # Reset to the fixed initial weights and retrain.
    print("Reset and train")
    strategy.net.net.load_state_dict(init_state_Unet)
    strategy.train()
    # Stored under rd + 1: this model queries in round rd + 1, and slot 1 already
    # holds the round-0 model that the cache check above reloads on re-runs.
    _save_main(rd + 1)

    logs.append(_evaluate_and_log(rd))
    print(logs[-1])

params['logs'] = logs
wandb.finish()

Round 0
Round 0 testing metrics: iou_score = 0.70, accuracy = 1.00, precision = 0.84, recall = 0.80, f1_score = 0.82
Round 1
Querying
[ 963 1900 1899 1898 1897 1894  955 1893 1889 1888  961 1887  949 1886
 1883 1882  968 1880 1877 1873]
Updating with sam
Training model_1 for voting


 29%|██████████████▌                                    | 10/35 [01:54<04:46, 11.47s/it, loss=0.169]

In [None]:
# from pathlib import Path
# Path("./my/directory").mkdir(parents=True, exist_ok=True)

import os



In [None]:
# Show the working directory; relative paths used above (params.json, trained_models/...) resolve against it.
os.getcwd()

'/root/Master_Thesis/scripts/notebooks'

In [None]:
# for dirname, _, filenames in os.walk(expirements_path):
#     filename = "expirement_{}.json".format(len(filenames))
#     file_path = os.path.join(dirname, filename)
#     with open(file_path, 'w') as f:
#         json.dump(params, f)
#         print(filename)

# UNet

In [None]:
# strategy.net = net_1

In [None]:
# model_1 = torch.load('model_1.pt')
# strategy.net.net.load_state_dict(model_1)
# strategy.net.clf = strategy.net.net.to(torch.device("cuda:1"))
# mask_1 = strategy.predict(strategy.dataset.handler([strategy.dataset.df["images"][idx]], [strategy.dataset.df["masks"][idx]]))[0]
# mask_1 = (mask_1.squeeze().cpu().sigmoid()> 0.5).float()

In [None]:
# model_2 = torch.load('model_2.pt')
# strategy.net.net.load_state_dict(model_2)
# strategy.net.clf = strategy.net.net.to(torch.device("cuda:1"))
# mask_2 = strategy.predict(strategy.dataset.handler([strategy.dataset.df["images"][idx]], [strategy.dataset.df["masks"][idx]]))[0]
# mask_2 = (mask_2.squeeze().cpu().sigmoid()> 0.5).float()

In [None]:
# model_3 = torch.load('model_3.pt')
# strategy.net.net.load_state_dict(model_3)
# strategy.net.clf = strategy.net.net.to(torch.device("cuda:1"))
# mask_3 = strategy.predict(strategy.dataset.handler([strategy.dataset.df["images"][idx]], [strategy.dataset.df["masks"][idx]]))[0]
# mask_3 = (mask_3.squeeze().cpu().sigmoid()> 0.5).float()

# FPN

In [None]:
# strategy.net = net_2

In [None]:
# model_4 = torch.load('model_4.pt')
# strategy.net.net.load_state_dict(model_4)
# strategy.net.clf = strategy.net.net.to(torch.device("cuda:1"))
# mask_4 = strategy.predict(strategy.dataset.handler([strategy.dataset.df["images"][idx]], [strategy.dataset.df["masks"][idx]]))[0]
# mask_4 = (mask_4.squeeze().cpu().sigmoid()> 0.5).float()

In [None]:
# def predict(model_states:list, idx):
#     masks = []
#     for state in model_states:
#         model = torch.load(state)
#         strategy.net.net.load_state_dict(model)
#         strategy.net.clf = strategy.net.net.to(torch.device("cuda:1"))
#         mask = strategy.predict(strategy.dataset.handler([strategy.dataset.df["images"][idx]], [strategy.dataset.df["masks"][idx]]))[0]
#         mask = (mask.squeeze().cpu().sigmoid()> 0.5).float()
#         masks.append(mask)
#     return masks

In [None]:
# masks = predict(["model_1.pt", "model_2.pt", "model_3.pt", "model_4.pt"], 2197)

In [None]:
# import supervision as sv
# imgs = [mask_1, mask_2, mask_3, mask_4]
# sv.plot_images_grid(
#     images=imgs,
#     grid_size=(1, len(imgs)),
#     # titles=['mask_1', 'mask_2', "mask_3","mask_4"]
# )