In [1]:
import sys
sys.path.append('../src')  # make the project's src/ modules (models, strategies, custom_datasets) importable
from models import *
from strategies import *
from custom_datasets import *
import numpy as np
import tqdm

import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import os
import random
import torch
import torch.nn as nn

from PIL import Image
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
import pandas as pd

import time
import json

# Seed every RNG in play. np.random.seed alone does not affect torch, and torch's
# RNG drives the random initialisation of the model that is later snapshotted as
# the shared starting point for every training round.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices
torch.cuda.empty_cache()



In [2]:
# Absolute project root; every other path below is built by appending to it
# (directories keep their trailing "/" because later cells append file names directly).
main_path = "/root/Master_Thesis/"
dataframes_path = f"{main_path}data/dataframes/"   # <name>_train.csv / <name>_test.csv splits
sam_path = f"{main_path}sam/sam_vit_h_4b8939.pth"  # checkpoint handed to SAMOracle
expirements_path = f"{main_path}expirements/"      # where experiment JSON files are written

In [3]:
# Pick the dataset by name; its splits are read from
#   f"{dataframes_path}{df_name}_train.csv" and f"{dataframes_path}{df_name}_test.csv",
# so switching datasets is a one-word change that cannot desync name and files.
# Other datasets that were tried:
#   "brain_df"
#   "fire_df"   - too many misclassifications
#   "aerial_df" - the model couldn't learn from it
df_name = "lung_df"
train_df = pd.read_csv(dataframes_path + df_name + "_train.csv")
test_df = pd.read_csv(dataframes_path + df_name + "_test.csv")


In [4]:
# Number of rows in the test split (displayed as the cell output).
test_df.shape[0]

4731

In [5]:
# Experiment configuration. This dict is passed to Net, read throughout the
# training loop, and finally dumped to the experiment JSON.
params = dict(
    n_epoch=20,
    train_args=dict(batch_size=128, num_workers=1),
    test_args=dict(batch_size=256, num_workers=1),
    optimizer_args=dict(lr=0.00005, momentum=0.9),
    use_sam=False,
    use_predictor=False,
    use_generator=False,
    init_set_size=len(train_df),  # alternative: 20
    query_num=4,                  # alternative: int(0.1*len(test_df))
    rounds=1,
    activate_sam_at_round=4,
    test_set_size=len(test_df),
    df=df_name,
)
print(params)

{'n_epoch': 20, 'train_args': {'batch_size': 128, 'num_workers': 1}, 'test_args': {'batch_size': 256, 'num_workers': 1}, 'optimizer_args': {'lr': 0.005, 'momentum': 0.9}, 'use_sam': False, 'use_predictor': False, 'use_generator': False, 'init_set_size': 11036, 'query_num': 4, 'rounds': 1, 'activate_sam_at_round': 4, 'test_set_size': 4731, 'df': 'lung_df'}


In [6]:
# Build the SAM oracle from sam_path only when SAM is enabled; strategies receive None otherwise.
sam = SAMOracle(checkpoint_path=sam_path) if params['use_sam'] else None

In [7]:
model = smp.create_model(
    'FPN', encoder_name='resnet34', in_channels=3, classes=1
)
# Snapshot the freshly initialised weights to disk. A file round-trip (rather than
# keeping model.state_dict(), whose tensors alias the live parameters) gives an
# independent copy; training rounds reload 'init_state.pt' so each starts from it.
torch.save(model.state_dict(), 'init_state.pt')
init_state = torch.load('init_state.pt', map_location='cpu')

# Prefer the second GPU as before, but degrade gracefully on machines with fewer
# GPUs instead of failing with an invalid-device error.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
net = Net(model, params, device=device)

In [8]:
def get_data(handler, train_df, test_df, use_sam=None, data_path=None):
    """Build the active-learning ``Data`` container from train/test dataframes.

    Args:
        handler: Dataset handler class forwarded to ``Data``.
        train_df: DataFrame with ``images`` and ``masks`` columns for the training
            pool; it is also passed to ``Data`` as ``df``.
        test_df: DataFrame with ``images`` and ``masks`` columns for the test set.
        use_sam: Forwarded to ``Data``; defaults to the global ``params['use_sam']``.
        data_path: Root of the processed data; defaults to
            ``<main_path>data/processed/`` (built from the global ``main_path``).

    Returns:
        The ``Data`` instance.
    """
    if use_sam is None:
        use_sam = params['use_sam']
    if data_path is None:
        # main_path already ends with "/"; joining avoids the "//" that string
        # concatenation with "/data/..." produced. The trailing "" keeps the final "/".
        data_path = os.path.join(main_path, "data", "processed", "")
    return Data(
        train_df["images"].to_list(),
        train_df["masks"].to_list(),
        test_df["images"].to_list(),
        test_df["masks"].to_list(),
        handler,
        df=train_df,
        path=data_path,
        use_sam=use_sam,
    )


In [9]:
# Wrap both splits in the AL data container, then set up the initial labelled
# pool using params["init_set_size"].
data = get_data(handler=Handler, train_df=train_df, test_df=test_df)
data.initialize_labels(params["init_set_size"])

### Choose an AL strategy: a) RandomSampling, b) MarginSampling, c) EntropySampling, d) KCenterGreedy, e) AdversarialBIM, f) BALDDropout

In [10]:
# Margin sampling; its network is reset to the shared initial weights so that
# every strategy starts from the same point.
strategy = MarginSampling(
    dataset=data,
    net=net,
    sam=sam,
)
strategy.net.net.load_state_dict(init_state)
params["strategy"] = "MarginSampling"  # stored in the experiment JSON

In [11]:
def _format_metrics(iou_score, accuracy, precision, recall, f1_score):
    """Format the five test metrics (2 decimals each) exactly as they appear in every log line."""
    return (
        f"iou_score = {iou_score:.2f}, accuracy = {accuracy:.2f}, precision = {precision:.2f}, "
        f"recall = {recall:.2f}, f1_score = {f1_score:.2f}"
    )


def _save_experiment(record, directory):
    """Dump ``record`` as JSON to the next unused ``expirement_<k>.json`` in ``directory``.

    ``k`` starts at the number of non-directory entries already in ``directory``
    (the existing numbering scheme) and is bumped past any name that is already
    taken, so an earlier experiment is never overwritten. Only ``directory`` itself
    is written to (not its subdirectories), and it is created if missing.

    Returns:
        The file name that was written.
    """
    os.makedirs(directory, exist_ok=True)
    with os.scandir(directory) as entries:
        k = sum(1 for entry in entries if not entry.is_dir())
    while os.path.exists(os.path.join(directory, f"expirement_{k}.json")):
        k += 1
    filename = f"expirement_{k}.json"
    with open(os.path.join(directory, filename), 'w') as f:
        json.dump(record, f)
    return filename


torch.cuda.empty_cache()
logs = []

print("Round 0")
strategy.train()
logits, mask_gt = strategy.predict(data.get_test_data())
iou_score, accuracy, precision, recall, f1_score = data.cal_test_metrics(logits, mask_gt)
logs.append(f"Round 0 testing metrics: {_format_metrics(iou_score, accuracy, precision, recall, f1_score)}")
print(logs[-1])

for rd in range(1, params["rounds"]):
    print(f"Round {rd}")

    # Ask the strategy which unlabelled samples to label next.
    print("Querying")
    query_idxs = strategy.query(params["query_num"])
    print(query_idxs)

    # Add the queried samples to the labelled pool, through SAM once it is enabled.
    if params["use_sam"] and rd >= params["activate_sam_at_round"]:
        print("Updating with sam")
        strategy.update(query_idxs, start_sam=True, use_predictor=params["use_predictor"], use_generator=params["use_generator"])
        print("Sam failed to mask: ", strategy.sam_failed)
    else:
        print("Updating without sam")
        strategy.update(query_idxs)

    # Retrain from the saved initial weights instead of continuing from the previous round.
    print("Reset and train")
    init_state = torch.load('init_state.pt')
    strategy.net.net.load_state_dict(init_state)
    strategy.train()

    # Evaluate on the test set using this round's predictions and ground truth.
    logits, mask_gt = strategy.predict(data.get_test_data())
    iou_score, accuracy, precision, recall, f1_score = data.cal_test_metrics(logits, mask_gt)
    # NOTE(review): this logs strategy.human_envolved under the label `sam_failed`;
    # confirm that label is intended.
    logs.append(
        f"Round {rd} testing metrics: {_format_metrics(iou_score, accuracy, precision, recall, f1_score)}, "
        f"sam_failed = {strategy.human_envolved}"
    )
    strategy.human_envolved = 0  # reset the per-round counter
    print(logs[-1])

params['logs'] = logs
print(_save_experiment(params, expirements_path))

Round 0
11036


 45%|██████████████████████▉                            | 9/20 [40:15<49:11, 268.35s/it, loss=0.984]


KeyboardInterrupt: 

In [None]:
# data = get_data(Handler, train_df, test_df)
# data.initialize_labels(params["init_set_size"])
# strategy = EntropySampling(dataset=data, net=net, sam=sam)
# strategy.net.net.load_state_dict(init_state)
# params["strategy"] = "EntropySampling"

In [None]:
# torch.cuda.empty_cache()
# logs=[]
# print("Round 0")
# strategy.train()
# logits, mask_gt = strategy.predict(data.get_test_data())
# iou_score, accuracy, precision, recall, f1_score = data.cal_test_metrics(logits, mask_gt )
# logs.append(f"Round 0 testing metrics: iou_score = {iou_score:.2f}, accuracy = {accuracy:.2f}, precision = {precision:.2f}, recall = {recall:.2f}, f1_score = {f1_score:.2f}")
# print(logs[0])

# for rd in range(1, params["rounds"]):
#     print(f"Round {rd}")

#     # query
#     print("Querying")
#     query_idxs = strategy.query(params["query_num"])
#     print(query_idxs)

#     # update labels
#     if params["use_sam"] and rd >= params["activate_sam_at_round"]:
#         print("Updating with sam")
#         strategy.update(query_idxs, start_sam=True, use_predictor=params["use_predictor"], use_generator=params["use_generator"])
#     else:
#         print("Updating without sam")
#         strategy.update(query_idxs)
    
#     print("Reset and train")
#     init_state = torch.load('init_state.pt')
#     strategy.net.net.load_state_dict(init_state)
#     strategy.train()

#     # calculate accuracy
#     logits, mask_gt = strategy.predict(data.get_test_data())
#     iou_score, accuracy, precision, recall, f1_score = data.cal_test_metrics(logits, mask_gt )
#     logs.append(f"Round {rd} testing metrics: iou_score = {iou_score:.2f}, accuracy = {accuracy:.2f}, precision = {precision:.2f}, recall = {recall:.2f}, f1_score = {f1_score:.2f}")
#     print(logs[rd])
    
# params['logs'] = logs

# for dirname, _, filenames in os.walk(expirements_path):
#     filename = "expirement_{}.json".format(len(filenames))
#     file_path = os.path.join(dirname, filename)
#     with open(file_path, 'w') as f:
#         json.dump(params, f)
#         print(filename)

In [None]:
# data = get_data(Handler, train_df, test_df)
# data.initialize_labels(params["init_set_size"])
# strategy = BALDDropout(dataset=data, net=net, sam=sam)
# strategy.net.net.load_state_dict(init_state)
# params["strategy"] = "BALDDropout"

In [None]:
# torch.cuda.empty_cache()
# logs=[]
# print("Round 0")
# strategy.train()
# logits, mask_gt = strategy.predict(data.get_test_data())
# iou_score, accuracy, precision, recall, f1_score = data.cal_test_metrics(logits, mask_gt )
# logs.append(f"Round 0 testing metrics: iou_score = {iou_score:.2f}, accuracy = {accuracy:.2f}, precision = {precision:.2f}, recall = {recall:.2f}, f1_score = {f1_score:.2f}")
# print(logs[0])

# for rd in range(1, params["rounds"]):
#     print(f"Round {rd}")

#     # query
#     print("Querying")
#     query_idxs = strategy.query(params["query_num"])
#     print(query_idxs)

#     # update labels
#     if params["use_sam"] and rd >= params["activate_sam_at_round"]:
#         print("Updating with sam")
#         strategy.update(query_idxs, start_sam=True, use_predictor=params["use_predictor"], use_generator=params["use_generator"])
#     else:
#         print("Updating without sam")
#         strategy.update(query_idxs)
    
#     print("Reset and train")
#     init_state = torch.load('init_state.pt')
#     strategy.net.net.load_state_dict(init_state)
#     strategy.train()

#     # calculate accuracy
#     logits, mask_gt = strategy.predict(data.get_test_data())
#     iou_score, accuracy, precision, recall, f1_score = data.cal_test_metrics(logits, mask_gt )
#     logs.append(f"Round {rd} testing metrics: iou_score = {iou_score:.2f}, accuracy = {accuracy:.2f}, precision = {precision:.2f}, recall = {recall:.2f}, f1_score = {f1_score:.2f}")
#     print(logs[rd])
    
# params['logs'] = logs

# for dirname, _, filenames in os.walk(expirements_path):
#     filename = "expirement_{}.json".format(len(filenames))
#     file_path = os.path.join(dirname, filename)
#     with open(file_path, 'w') as f:
#         json.dump(params, f)
#         print(filename)

In [None]:
# data = get_data(Handler, train_df, test_df)
# data.initialize_labels(params["init_set_size"])
# strategy = AdversarialBIM(dataset=data, net=net, sam=sam)
# strategy.net.net.load_state_dict(init_state)
# params["strategy"] = "AdversarialBIM"

In [None]:
# torch.cuda.empty_cache()
# logs=[]
# print("Round 0")
# strategy.train()
# logits, mask_gt = strategy.predict(data.get_test_data())
# iou_score, accuracy, precision, recall, f1_score = data.cal_test_metrics(logits, mask_gt )
# logs.append(f"Round 0 testing metrics: iou_score = {iou_score:.2f}, accuracy = {accuracy:.2f}, precision = {precision:.2f}, recall = {recall:.2f}, f1_score = {f1_score:.2f}")
# print(logs[0])

# for rd in range(1, params["rounds"]):
#     print(f"Round {rd}")

#     # query
#     print("Querying")
#     query_idxs = strategy.query(params["query_num"])
#     print(query_idxs)

#     # update labels
#     if params["use_sam"] and rd >= params["activate_sam_at_round"]:
#         print("Updating with sam")
#         strategy.update(query_idxs, start_sam=True, use_predictor=params["use_predictor"], use_generator=params["use_generator"])
#     else:
#         print("Updating without sam")
#         strategy.update(query_idxs)
    
#     print("Reset and train")
#     init_state = torch.load('init_state.pt')
#     strategy.net.net.load_state_dict(init_state)
#     strategy.train()

#     # calculate accuracy
#     logits, mask_gt = strategy.predict(data.get_test_data())
#     iou_score, accuracy, precision, recall, f1_score = data.cal_test_metrics(logits, mask_gt )
#     logs.append(f"Round {rd} testing metrics: iou_score = {iou_score:.2f}, accuracy = {accuracy:.2f}, precision = {precision:.2f}, recall = {recall:.2f}, f1_score = {f1_score:.2f}")
#     print(logs[rd])
    
# params['logs'] = logs

# for dirname, _, filenames in os.walk(expirements_path):
#     filename = "expirement_{}.json".format(len(filenames))
#     file_path = os.path.join(dirname, filename)
#     with open(file_path, 'w') as f:
#         json.dump(params, f)
#         print(filename)

In [None]:
# data = get_data(Handler, train_df, test_df)
# data.initialize_labels(params["init_set_size"])
# strategy = KCenterGreedy(dataset=data, net=net, sam=sam)
# strategy.net.net.load_state_dict(init_state)
# params["strategy"] = "KCenterGreedy"

In [None]:
# torch.cuda.empty_cache()
# logs=[]
# print("Round 0")
# strategy.train()
# logits, mask_gt = strategy.predict(data.get_test_data())
# iou_score, accuracy, precision, recall, f1_score = data.cal_test_metrics(logits, mask_gt )
# logs.append(f"Round 0 testing metrics: iou_score = {iou_score:.2f}, accuracy = {accuracy:.2f}, precision = {precision:.2f}, recall = {recall:.2f}, f1_score = {f1_score:.2f}")
# print(logs[0])

# for rd in range(1, params["rounds"]):
#     print(f"Round {rd}")

#     # query
#     print("Querying")
#     query_idxs = strategy.query(params["query_num"])
#     print(query_idxs)

#     # update labels
#     if params["use_sam"] and rd >= params["activate_sam_at_round"]:
#         print("Updating with sam")
#         strategy.update(query_idxs, start_sam=True, use_predictor=params["use_predictor"], use_generator=params["use_generator"])
#     else:
#         print("Updating without sam")
#         strategy.update(query_idxs)
    
#     print("Reset and train")
#     init_state = torch.load('init_state.pt')
#     strategy.net.net.load_state_dict(init_state)
#     strategy.train()

#     # calculate accuracy
#     logits, mask_gt = strategy.predict(data.get_test_data())
#     iou_score, accuracy, precision, recall, f1_score = data.cal_test_metrics(logits, mask_gt )
#     logs.append(f"Round {rd} testing metrics: iou_score = {iou_score:.2f}, accuracy = {accuracy:.2f}, precision = {precision:.2f}, recall = {recall:.2f}, f1_score = {f1_score:.2f}")
#     print(logs[rd])
    
# params['logs'] = logs

# for dirname, _, filenames in os.walk(expirements_path):
#     filename = "expirement_{}.json".format(len(filenames))
#     file_path = os.path.join(dirname, filename)
#     with open(file_path, 'w') as f:
#         json.dump(params, f)
#         print(filename)