In [None]:
cd ../../msg

In [None]:
import copy

import numpy as np
from torch import nn
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from xt_training import metrics
from xt_training.utils import DummyOptimizer, SKDataset, SKInterface, functional

from utils import transforms as xt_transforms

### Transforms

In [None]:
# Transforms applied first, ahead of the per-split steps below.
# `RandomHorizontalFlip` is referenced through the imported `transforms`
# module: the bare name is not brought into scope by the import cell, so the
# unqualified call would raise NameError on a fresh kernel.
preprocessing_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
])

# Training pipeline: shared preprocessing, flatten each sample to 1-D with
# np.ravel, then convert with ToTensor.
# NOTE(review): torchvision documents ToTensor for PIL images / HxWxC
# ndarrays -- confirm it accepts the 1-D array np.ravel produces here.
train_transforms = transforms.Compose([
    preprocessing_transforms,
    np.ravel,
    transforms.ToTensor(),
])

# Validation reuses the training pipeline object as-is.
# NOTE(review): this means validation samples are also randomly flipped;
# confirm that is intended, otherwise build a flip-free pipeline here.
val_transforms = train_transforms

### Dataset

In [None]:
# Source dataset for the train/validation split, read from the location-2
# export under `base_dir`.
# NOTE(review): BlackTuskDataset, split_samples, base_dir, target_dict and
# path_exclude are not defined in the visible cells -- confirm they are in
# scope before this cell runs.
dataset = BlackTuskDataset(
    root=[base_dir + 'location-2/train-matrix-v16.van.medium_random'],
    transform=train_transforms,
    target_dict=target_dict,
    path_exclude=path_exclude,
)

# Index lists for the two splits (0.95 is presumably the training fraction --
# verify against split_samples).
train_inds, val_inds = split_samples(dataset.filepaths, 0.95)

# Validation reads from an independent deep copy of the dataset so that its
# transform can be set without affecting the training dataset object.
val_source = copy.deepcopy(dataset)
val_source.transform = val_transforms

# Restrict each dataset to its indices and wrap it in SKDataset.
train_dataset = SKDataset(Subset(dataset, train_inds))
val_dataset = SKDataset(Subset(val_source, val_inds))

# Named evaluation datasets; each entry later gets its own DataLoader.
# The 'test' entry reads the location-1 export under `base_dir`.
# NOTE(review): BlackTuskDataset, base_dir, target_dict and path_exclude are
# not defined in the visible cells -- confirm they are in scope.
test_datasets = {
    'test': SKDataset(BlackTuskDataset(
        root=[
            base_dir + 'location-1/train-matrix-v16.van.medium_random',
        ],
        # Evaluation data uses the evaluation pipeline rather than the
        # training one, matching how the validation dataset is configured.
        transform=val_transforms,
        target_dict=target_dict,
        path_exclude=path_exclude,
    ))
}

### Dataloaders

In [None]:
# Loader settings shared by every split.
# NOTE(review): the batch size is very large -- presumably so each loader
# yields its whole dataset as a single batch for the wrapped model; confirm.
batch_size = 1000000
num_workers = 4
loader_kwargs = {'batch_size': batch_size, 'num_workers': num_workers}

train_loader = DataLoader(train_dataset, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)

# One loader per named test dataset, keyed by the same name.
test_loaders = {
    ds_name: DataLoader(ds, **loader_kwargs)
    for ds_name, ds in test_datasets.items()
}

### Model & Optimizer

In [None]:
# Wrap a random-forest regressor in SKInterface with a single output.
# NOTE(review): RandomForestRegressor is not imported in the visible import
# cell -- presumably sklearn.ensemble.RandomForestRegressor; confirm it is in
# scope before this cell runs.
model = SKInterface(
    # A fixed random_state removes the estimator's own run-to-run randomness,
    # so refitting on the same data gives the same forest.
    base_model=RandomForestRegressor(random_state=0),
    output_dim=1,
)

# Optimizer object handed to functional.train alongside the model.
optimizer = DummyOptimizer()

### Model Training

In [None]:
# Directory passed to functional.train as `save_dir`.
save_dir = 'checkpoints/v2/blacktusk_demo/'

# Every argument for the training run, gathered in one place.
# NOTE(review): epochs, loss_fn, eval_metrics and default_train_exit are not
# defined in the visible cells -- confirm they are in scope.
train_kwargs = {
    'save_dir': save_dir,
    'train_loader': train_loader,
    'model': model,
    'optimizer': optimizer,
    'epochs': epochs,
    'loss_fn': loss_fn,
    'overwrite': True,
    'val_loader': val_loader,
    'test_loaders': test_loaders,
    'eval_metrics': eval_metrics,
    'on_exit': default_train_exit,
}

# Run training; the call returns a pair that is unpacked as (stats, matrix).
stats, matrix = functional.train(**train_kwargs)