In [1]:
import sys
sys.path.append('../src')
from models import *
from strategies import *
from custom_datasets import *
import numpy as np
np.random.seed(0)
import tqdm

import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import os
import torch
torch.cuda.empty_cache()
import torch.nn as nn

from PIL import Image
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
import pandas as pd

import time
import json

import threading
from unet_model import *



In [2]:
main_path = "/root/Master_Thesis/"
dataframes_path = main_path + "data/dataframes/"
sam_path = main_path + "sam/sam_vit_h_4b8939.pth"
notebooks_path = main_path + "scripts/notebooks/"
expirements_path = main_path+"expirements/"

In [3]:
# df_name = "brain_df"
# train_df = pd.read_csv(dataframes_path+"brain_df_train.csv")
# test_df = pd.read_csv(dataframes_path+"brain_df_test.csv")

df_name = "lung_tumor_df"
train_df = pd.read_csv(dataframes_path+"lung_df_train.csv")
test_df = pd.read_csv(dataframes_path+"lung_df_test.csv")

# df_name = "flood_df"
# train_df = pd.read_csv(dataframes_path+"flood_df_train.csv")
# test_df = pd.read_csv(dataframes_path+"flood_df_test.csv")

In [4]:
params = {'n_epoch': 35,
          'train_args':{'batch_size': 4, 'num_workers': 1},
          'test_args':{'batch_size': 256, 'num_workers': 1},
          'optimizer_args':{'lr': 5e-3, 'momentum': 0.9},
          'use_sam': False,
          'use_predictor': False,
          'use_generator': False,
          'init_set_size': len(train_df), #400,
          'rounds': 30,
          "activate_sam_at_round":1, 
          "img_size":(128, 128),
          "voting" : False,
          "pre_trained": False,
          "dataset": "Lung_Tumor_Segmentation"}

if params["init_set_size"] == len(train_df):
    params["training_type"] = "no_active"
elif not params["use_sam"]:
    params["training_type"] = "no_sam"
elif params["voting"]:
    params["training_type"] = "voters"
else:
    params["training_type"]="withSAM_NoVoting"

if params["training_type"] == "no_active":
    if params["pre_trained"]:
        params["model_path"] = f'{notebooks_path}trained_models/{params["dataset"]}/no_active/pre_trained_Unet_{params["img_size"][0]}.pt'
    else:
        params["model_path"] = f'{notebooks_path}trained_models/{params["dataset"]}/no_active/not_pre_trained_Unet_{params["img_size"][0]}.pt'
else:
    params["model_path"] = f'{notebooks_path}trained_models/{params["dataset"]}/{params["training_type"]}/{params["init_set_size"]}'
    
params['test_set_size'] = len(test_df)
params['df'] = df_name
params['query_num'] = int(0.05 * params['init_set_size'])
if params['query_num'] == 0:
    params['query_num'] = 1
params["strategy"] = "MarginSampling"

if params["training_type"] == "voters":
    params["model_path"] = f'{params["model_path"]}/voters_{params["img_size"][0]}'

In [5]:
def get_data(handler, train_df, test_df):
    return Data(train_df["images"].to_list(), train_df["masks"].to_list(), test_df["images"].to_list(), test_df["masks"].to_list(), handler, img_size=params["img_size"], df=train_df, path= main_path+"/data/processed/", use_sam=params['use_sam'])

In [6]:
data = get_data(Handler, train_df, test_df)
data.initialize_labels(params["init_set_size"])
results=[]

In [7]:
# for i in [1,2,3,4,5,6,7,8,9,10]:
#     model = smp.create_model(
#             'Unet', encoder_name='resnet34', in_channels=3, classes = 1
#         )
#     torch.save(model.state_dict(), f"trained_models/voters/voters_128_0/model_{i}.pt")
#     print(f"Model_{i}'s training saved!")

In [8]:

# def ensemble(models_num, starting_index, params, data, cuurent_round=1, query_idxs=None):
#     for i in range(starting_index, models_num+starting_index):
#         print(f"Model_{i}'s training started!", flush=True)
#         model = smp.create_model(
#                 'Unet', encoder_name='resnet34', in_channels=3, classes = 1
#             )
#         init_state_Unet = torch.load(f"trained_models/voters/voters_128_0/model_{i}.pt")
#         net = Net(model, params, device = torch.device("cuda"))
#         net.net.load_state_dict(init_state_Unet)
#         strategy = MarginSampling(dataset=data, net=net, sam=None, params=params)
#         if not query_idxs is None:
#             strategy.update(query_idxs)
        
#         strategy.train()
#         torch.save( strategy.net.net.state_dict(), f'{params["voters"]}{cuurent_round}/model_{i}.pt')
#         logits, mask_gt = strategy.predict(data.get_test_data())
#         iou_score, accuracy, precision, recall, f1_score = data.cal_test_metrics(logits, mask_gt )
#         print(f"Testing metrics for model_{i}: iou_score = {iou_score:.2f}, accuracy = {accuracy:.2f}, precision = {precision:.2f}, recall = {recall:.2f}, f1_score = {f1_score:.2f}", flush=True)
#         print(f"Model_{i}'s saved!", flush=True)
#     # print("Done!")
    

In [9]:
# query_idxs = [1656,  121,  253,  968, 2095]

In [10]:
# ensemble(models_num=1, starting_index=3, params=params, data=data, cuurent_round=1, query_idxs=query_idxs)

In [11]:
# # for i in range(1, 11):
# for i in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]:
#     t = threading.Thread(target=ensemble, daemon=True, args=[1, i, params, data, 1, query_idxs])
#     t.start()

In [12]:
# t1 = threading.Thread(target=ensemble, daemon=True, args=[5, 1, params, data, query_idxs])
# t1.start()

In [13]:
# t2 = threading.Thread(target=ensemble, daemon=True, args=[5, 6, params, data, query_idxs])
# t2.start()

In [14]:
if params["pre_trained"]:
    model = smp.create_model('Unet', encoder_name='resnet34', in_channels=3, classes = 1)
else:
    model = UNet(n_channels=3, n_classes=1, bilinear=True)
net = Net(model, params, device = torch.device("cuda"))

In [15]:
params["model_path"]

'/root/Master_Thesis/scripts/notebooks/trained_models/Lung_Tumor_Segmentation/no_active/not_pre_trained_Unet_128.pt'

In [16]:
wandb.init(
    # set the wandb project where this run will be logged
    project=params["dataset"],
    
    notes = f'{params["training_type"]}_{params["init_set_size"]}',

    # track hyperparameters and run metadata
    config=params    
)

net.net.load_state_dict(torch.load(params["model_path"]))
logits, mask_gt = net.predict(data.get_test_data())
iou_score, accuracy, precision, recall, f1_score = data.cal_test_metrics(logits, mask_gt )
for i in range(30):
    wandb.log({"iou_score" : iou_score, "accuracy" : accuracy, "precision" : precision, "recall" : recall, "f1_score" : f1_score})
wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msaleemfares1995-sf[0m ([33mthesis_fares[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='0.005 MB of 0.010 MB uploaded\r'), FloatProgress(value=0.526578871479381, max=1.0)…

0,1
accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1_score,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
iou_score,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
precision,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
recall,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,0.99972
f1_score,0.95949
iou_score,0.92213
precision,0.97346
recall,0.94591


In [15]:
def test(net, idx, params, data, current_round=1):
    for i in idx:
        net.net.load_state_dict(torch.load(f'{params["voters"]}{current_round}/model_{i}.pt'))
        logits, mask_gt = net.predict(data.get_test_data())
        iou_score, accuracy, precision, recall, f1_score = data.cal_test_metrics(logits, mask_gt )
        print((f"Testing metrics for model_{i}: iou_score = {iou_score:.2f}, accuracy = {accuracy:.2f}, precision = {precision:.2f}, recall = {recall:.2f}, f1_score = {f1_score:.2f}"))
    

In [16]:
idx = [i for i in range(1,10)]
idx

[1, 2, 3, 4, 5, 6, 7, 8, 9]

In [28]:
# test(net, idx, params, data, current_round=0)

In [48]:
# test(net, idx, params, data, current_round=1)

In [47]:
# test(net, idx, params, data, current_round=2)

In [46]:
# test(net, idx, params, data, current_round=3)

In [45]:
# test(net, idx, params, data, current_round=4)

In [44]:
# test(net, idx, params, data, current_round=5)

In [43]:
# test(net, idx, params, data, current_round=6)

In [42]:
# test(net, idx, params, data, current_round=7)

In [41]:
# test(net, idx, params, data, current_round=8)

In [40]:
# test(net, idx, params, data, current_round=9)

In [39]:
# test(net, idx, params, data, current_round=10)

In [38]:
# test(net, idx, params, data, current_round=11)

In [37]:
# test(net, idx, params, data, current_round=12)

In [36]:
# test(net, idx, params, data, current_round=13)

In [35]:
# test(net, idx, params, data, current_round=14)

In [34]:
# test(net, idx, params, data, current_round=15)

In [33]:
# test(net, idx, params, data, current_round=16)

In [32]:
# test(net, idx, params, data, current_round=17)

In [31]:
# test(net, idx, params, data, current_round=18)

In [30]:
# test(net, idx, params, data, current_round=19)

In [29]:
# test(net, idx, params, data, current_round=20)