# In [1]:
'''
    Training script. Here, we load the training and validation datasets (and
    data loaders) and the model and train and validate the model accordingly.

    2022 Benjamin Kellenberger
'''

import os
import argparse
import yaml
import glob
from tqdm import trange

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import SGD

# let's import our own classes and functions!
from evaluation.util_eval import init_seed
from evaluation.my_custom_dataset_eval import CTDataset ###??
from model_eval import CustomResNet18, CustomResNet50, SimClrPytorchResNet50, PAWSResNet50 
from sklearn.metrics import balanced_accuracy_score
from sklearn.utils import class_weight
import numpy as np
import wandb



def create_dataloader(cfg, split='train'):
    '''
        Build a DataLoader over the test data plus normalized class weights.

        NOTE(review): ``split`` is accepted but currently ignored. The dataset
        is always constructed as ``CTDataset(cfg, split='test_data',
        transform=False)``. An earlier version passed the argument through
        as ``CTDataset(cfg, split)``. Confirm this hard-coding is intended
        before reusing this helper for training/validation loaders.

        Args:
            cfg: configuration mapping, passed whole to ``CTDataset``. This
                function itself reads ``cfg['device']``,
                ``cfg['batch_size']`` and ``cfg['num_workers']``.
            split: unused; kept for interface compatibility.

        Returns:
            A ``(loader, class_weights)`` tuple:
              - loader: a shuffling ``DataLoader`` over the dataset.
              - class_weights: a ``torch.float`` tensor moved to
                ``cfg['device']``. It holds scikit-learn "balanced" class
                weights rescaled so they sum to 1, with one entry per
                distinct label value, in ``np.unique`` (sorted) order.
    '''
    # transform=False presumably disables CTDataset's image transforms;
    # TODO confirm against evaluation.my_custom_dataset_eval.
    dataset = CTDataset(cfg, split='test_data', transform=False)

    target_device = cfg['device']

    # No drop_last is passed, so DataLoader's default (drop_last=False)
    # applies: the final partial batch is kept and every sample is counted.
    loader = DataLoader(
        dataset=dataset,
        batch_size=cfg['batch_size'],
        shuffle=True,
        num_workers=cfg['num_workers'],
    )

    # Make one full pass over the loader solely to gather labels. The input
    # half of every batch is still produced and then discarded.
    seen_labels = []
    for _, batch_labels in loader:
        seen_labels.extend(batch_labels.numpy())

    label_array = np.array(seen_labels)
    # NOTE(review): only labels that actually occur get a weight. If some
    # label id is absent, positions in this vector will not line up with
    # label ids. Verify that every class is present in the data.
    present_classes = np.unique(label_array)
    raw_weights = class_weight.compute_class_weight(
        'balanced', classes=present_classes, y=label_array
    )
    normalized = raw_weights / np.sum(raw_weights)

    weights_tensor = torch.tensor(normalized, dtype=torch.float).to(target_device)
    return loader, weights_tensor

# Notebook cell output: ModuleNotFoundError: No module named 'my_custom_dataset_eval'