In [None]:
import os
import torch
import yaml
from easydict import EasyDict as edict
from datetime import datetime

import ctools
from data.multi_view_data_injector import MultiViewDataInjector
from data.transforms import get_simclr_data_transforms
from models.mlp_head import MLPHead
from models.base_network import EfficientNet, ResNet
from trainer import BYOLTrainer
from data.reader import loader

# print(torch.__version__)
# Fix the seed of PyTorch's global RNG so repeated runs start from the same
# random state. NOTE(review): only torch is seeded here; Python's `random` and
# NumPy are not — confirm whether the data pipeline relies on them.
torch.manual_seed(0)

In [None]:
# Load the YAML configuration and wrap it in an EasyDict for attribute access.
# The file is opened in a context manager so the handle is closed as soon as
# parsing finishes instead of being left open until garbage collection.
# NOTE(review): yaml.FullLoader is more permissive than yaml.safe_load; only
# point this at a trusted configuration file.
with open("./config/config.yaml", "r") as config_file:
    config = edict(yaml.load(config_file, Loader=yaml.FullLoader))

# Shortcuts for the config sections used by the cells below.
data = config.data
save = config.save
batch_size = config.trainer.batch_size

In [None]:
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Training with: {device}")

# Build a single transform pipeline from the 'data_transforms' config section
# and hand the same pipeline object to MultiViewDataInjector twice
# (presumably one entry per view — confirm against the injector).
data_transform = get_simclr_data_transforms(**config['data_transforms'])
transform = MultiViewDataInjector([data_transform] * 2)

In [None]:
# When the config flags the dataset as a folder, pass it through
# ctools.readfolder; the first returned value rebinds `data`, the second is
# discarded.
if data.isFolder:
    data, _ = ctools.readfolder(data)

# Shuffled training loader built from the (possibly rebound) data description,
# the configured batch size and the multi-view transform.
train_loader = loader(
    data,
    batch_size,
    transform,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Build the online network. The backbone class is chosen by substring match on
# the configured model name, checking "resnet" before "efficientnet".
for _keyword, _backbone_cls in (("resnet", ResNet), ("efficientnet", EfficientNet)):
    if _keyword in config['network']['name']:
        online_network = _backbone_cls(**config['network']).to(device)
        break
else:
    # Neither keyword appears in the configured name.
    raise ValueError(f"Model {config['network']['name']} not available.")

In [None]:
# Optionally warm-start the online network from a checkpoint named in the config.
# An empty/falsy 'pretrain' entry means training starts from scratch.
pretrained_path = config['network']['pretrain']

if pretrained_path:
    try:
        # map_location places the stored tensors on the active device, so the
        # checkpoint does not have to be loaded on the device it was saved from.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        load_params = torch.load(pretrained_path, map_location=torch.device(device))
    except FileNotFoundError:
        # A missing checkpoint file is tolerated on purpose: report and continue.
        print("Pre-trained weights not found. Training from scratch.")
    else:
        # Only the online network's weights are restored from the checkpoint.
        online_network.load_state_dict(load_params['online_network_state_dict'])

In [None]:
# Build the predictor MLP. Its input width is read from the last layer of the
# online network's projection head; the remaining arguments come from the
# 'projection_head' section of the network config.
projection_out_features = online_network.projection.net[-1].out_features
predictor = MLPHead(
    in_channels=projection_out_features,
    **config['network']['projection_head'],
).to(device)

In [None]:
# Build the target network with the same backbone selection rule and the same
# constructor arguments as the online network.
network_cfg = config['network']
if "resnet" in network_cfg['name']:
    target_cls = ResNet
elif "efficientnet" in network_cfg['name']:
    target_cls = EfficientNet
else:
    raise ValueError(f"Model {config['network']['name']} not available.")

target_network = target_cls(**network_cfg).to(device)

In [None]:
# SGD over the parameters of the online network and the predictor; the target
# network's parameters are not handed to the optimizer. Hyper-parameters come
# from config['optimizer']['params'].
trainable_params = [*online_network.parameters(), *predictor.parameters()]
optimizer = torch.optim.SGD(trainable_params, **config['optimizer']['params'])

In [None]:
# Directory for model checkpoints and logs, laid out as
# <save.metapath>/<data.name>/<network name>/<YYYYmmdd-HHMMSS>.
run_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = os.path.join(save.metapath, data.name, config['network']['name'], run_stamp)

# exist_ok=True creates the directory tree if needed and is a no-op when it is
# already there, avoiding the race in a separate exists()-then-makedirs() check.
os.makedirs(log_dir, exist_ok=True)

In [None]:
# Assemble the BYOL trainer from the components built above; every key of the
# 'trainer' config section is forwarded as an additional keyword argument.
trainer_components = dict(
    online_network=online_network,
    target_network=target_network,
    optimizer=optimizer,
    predictor=predictor,
    device=device,
    log_dir=log_dir,
)
trainer = BYOLTrainer(**trainer_components, **config['trainer'])

# Kick off training on the prepared loader.
trainer.train(train_loader)