In [None]:
from kaggle_secrets import UserSecretsClient # see https://www.kaggle.com/discussions/product-feedback/114053 for more info
import sys
import os 
import torch 


# Fetch the secret stored under the label "git-pat" from Kaggle's secret store.
# A later cell interpolates `personal_token` into the GitHub clone URL, so the
# token itself never has to be written in the notebook.
user_secrets = UserSecretsClient()
personal_token = user_secrets.get_secret("git-pat")

In [None]:
# !git clone https://{personal_token}@github.com/JulianRodd/MLiP_group_10_task1_HMS.git # for generic 
branch = "bug-hunt"
!git clone -b {branch} https://{personal_token}@github.com/JulianRodd/MLiP_group_10_task1_HMS.git # for branch
    
sys.path.insert(1, "/kaggle/working/MLiP_group_10_task1_HMS") # pos 1 to avoid conflicts

In [None]:
# Make the cloned repository the working directory; later cells use paths
# relative to it (e.g. '../checkpoints/best_models').
os.chdir('/kaggle/working/MLiP_group_10_task1_HMS')
# Bring the checkout up to date with its remote branch.
!git pull

In [None]:
from generics import Paths 


# Rewrite every path stored on `Paths` so it points into the Kaggle filesystem,
# create the directory if it is missing, and print each value before and after.
for attr in dir(Paths):
    if attr.startswith("__"):
        continue

    path = getattr(Paths, attr)
    # dir() also lists methods and other non-string attributes; they are not
    # paths and would crash on the string operations below, so leave them alone.
    if not isinstance(path, str):
        continue
    print(f"{attr}: {path}")

    # Turn a "./relative" path into "/relative" so it can be re-rooted below.
    if path.startswith("./"):
        path = path[1:]

    # NOTE(review): this removes every "/data" substring, not only a leading
    # "/data" directory component - confirm that is intended.
    path = path.replace("/data", "")

    # Re-root anything that is not already under /kaggle/input/ or /kaggle/working/.
    # Accepting /kaggle/working/ here keeps the cell idempotent: on a second run the
    # already-rewritten paths are no longer nested as /kaggle/working/kaggle/working/...
    if not path.startswith(('/kaggle/input/', '/kaggle/working/')):
        path = '/kaggle/working/' + path
    path = path.replace('//', '/')

    # exist_ok avoids the separate existence check and the race it implies.
    os.makedirs(path, exist_ok=True)

    setattr(Paths, attr, path)
    path = getattr(Paths, attr)
    print(f"{attr}: {path}")
        


In [None]:
# from datasets.data_loader import CustomDataset
# from utils.loader_utils import load_main_dfs
# from datasets.data_loader_configs import BaseLarge


# data_loader_config = BaseLarge
# train_df, val_df, test_df = load_main_dfs(data_loader_config, train_val_split=(0.8, 0.2))

    
# #     Load datasets
# train_dataset = CustomDataset(config=data_loader_config, main_df = train_df, mode="train", cache=True, augment=True)
# val_dataset = CustomDataset(config=data_loader_config,main_df = val_df, mode="val", cache=True, augment=False)
    
# loader = train_dataset.get_torch_data_loader()

# i, (x, y) = next(enumerate(loader))
# x.shape

In [None]:
from datasets.data_loader import CustomDataset
from datasets.data_loader_configs import BaseFinetuning, BasePretraining, BaseDataConfig, BaseLarge
from generics import Paths
from models.CustomModel import CustomModel
from models.custom_model_configs import ResNetBase_LargeCF, BaseModelConfig
from utils.general_utils import get_logger
# from utils.inference_utils import perform_inference
from utils.loader_utils import load_main_dfs
from utils.training_utils import train


class EffNetControl(BaseModelConfig):
    """Model configuration using the 'tf_efficientnet_b0' model name.

    Only the attributes listed here are overridden; every other setting is
    inherited from BaseModelConfig.
    """

    # Model selection and head options
    MODEL = 'tf_efficientnet_b0'
    LARGE_CLASSIFIER = False
    FREEZE = False

    # Training length and batching
    EPOCHS = 4
    GRADIENT_ACCUMULATION_STEPS = 1

    # Optimisation settings
    WEIGHT_DECAY = 0.01
    MAX_GRAD_NORM = 1e7  # NOTE(review): very large - presumably disables clipping in practice; confirm
    AMP = True  # presumably automatic mixed precision - confirm in the training code
    
    
class ResNetBase_LargeCF(BaseModelConfig):
    """Model configuration using the 'resnet50' model name with LARGE_CLASSIFIER enabled.

    Only the attributes listed here are overridden; every other setting is
    inherited from BaseModelConfig.
    """
    # NOTE(review): this definition replaces the `ResNetBase_LargeCF` imported from
    # models.custom_model_configs earlier in this cell; from here on the name refers
    # to this class. Confirm the override is intentional.

    # Model selection and head options
    MODEL = 'resnet50'
    LARGE_CLASSIFIER = True
    FREEZE = False

    # Training length and batching
    EPOCHS = 5
    GRADIENT_ACCUMULATION_STEPS = 1

    # Optimisation settings
    WEIGHT_DECAY = 0.01


def main_train(model_config, fine_tune=False, weights=None, data_loader_config=None):
    logger = get_logger("main")
    
    if fine_tune and data_loader_config is None:
        data_loader_config = BaseFinetuning
    elif data_loader_config is None: 
        data_loader_config = BasePretraining
            
    logger.info(f"Training model {model_config.NAME} with data loader {data_loader_config.NAME}")
    
    train_df, val_df, test_df = load_main_dfs(data_loader_config, train_val_split=(0.8, 0.2))
    
    
#     Load datasets
    train_dataset = CustomDataset(config=data_loader_config, main_df = train_df, mode="train", cache=True, augment=True)
    val_dataset = CustomDataset(config=data_loader_config,main_df = val_df, mode="val", cache=True, augment=False)
    

#     Print summaries
    train_dataset.print_summary()
    val_dataset.print_summary()

    
    # Initialize and train the model
    
    model = CustomModel(model_config)
    if weights is not None:
        model.load_state_dict(weights)
    
    %load_ext tensorboard
    %tensorboard --logdir ../logs
    train(model=model, train_dataset=train_dataset, val_dataset=val_dataset, tensorboard_prefix="effnet_check")


In [None]:
# Override the subset size on the shared BaseLarge config and echo it back.
# NOTE(review): presumably 0 means "use the full dataset" - confirm in data_loader_configs.
BaseLarge.SUBSET_SAMPLE_COUNT = 0
print(BaseLarge.SUBSET_SAMPLE_COUNT)

# Train the EffNetControl model configuration on the BaseLarge data configuration.
main_train(model_config=EffNetControl, data_loader_config=BaseLarge)

In [None]:
# weights = torch.load('../checkpoints/best_models/best_resnet50_ResNetBase_LargeCF_BasePretraining.pth')
# main_train(fine_tune=True, weights=weights) # maybe start with lower LR here? 

In [None]:
# List the entries of the TensorBoard inference directory (raises if it does not exist).
os.listdir('/kaggle/working/tensorboard/inference')

In [None]:
# os.listdir returns entries in arbitrary order; sort them so the listing, and
# any later indexing such as best[0], is reproducible across runs.
best = sorted(os.listdir('../checkpoints/best_models'))
best

In [None]:
from IPython.display import FileLink

# Display a download link for the first entry of `best` (defined in the previous
# cell); raises IndexError when no checkpoint has been saved yet.
FileLink(f'../checkpoints/best_models/{best[0]}')



# for f in os.listdir('/kaggle/working/checkpoints/other_models'):
#     print(f)
#     FileLink(fr'{f}')
    