In [1]:
import time
import argparse
import logging
from tqdm import tqdm
import pandas as pd
from attrdict import AttrDict
from collections import defaultdict
from scipy.stats import gmean
import numpy as np
import random
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from imbalanced_regression.qsm.resnet import resnet50
from imbalanced_regression.qsm import loss
from imbalanced_regression.qsm.datasets import QSM
from imbalanced_regression.utils import *
import util
from util import pyvis
import sklearn.model_selection as skm
import os
import nibabel as nib
from IPython.display import HTML
import datetime
os.environ["KMP_WARNINGS"] = "FALSE"
HTML('''
<style>
.jupyter-matplotlib {
    background-color: #000;
}

.widget-label, .jupyter-matplotlib-header{
    color: #fff;
}

.jupyter-button {
    background-color: #333;
    color: #fff;
}
</style>
''')
%matplotlib widget

In [2]:
# Get case IDs
case_list = open('/home/ali/RadDBS-QSM/data/docs/cases_90','r')
lines = case_list.read()
lists = np.loadtxt(case_list.name,comments="#", delimiter=",",unpack=False,dtype=str)
case_id = []
for lines in lists:     
    case_id.append(lines[-9:-7])

# Load scores
file_dir = '/home/ali/RadDBS-QSM/data/docs/QSM anonymus- 6.22.2023-1528.csv'
motor_df = util.filter_scores(file_dir,'pre-dbs updrs','stim','CORNELL ID')
# Find cases with all required scores
subs,pre_imp,post_imp,pre_updrs_off = util.get_full_cases(motor_df,
                                                          'CORNELL ID',
                                                          'OFF (pre-dbs updrs)',
                                                          'ON (pre-dbs updrs)',
                                                          'OFF meds ON stim 6mo')
# Find overlap between scored subjects and nii
ids = np.asarray(case_id).astype(int)
ids = ids[ids != 54]
cases_idx = np.in1d(subs,ids)
ccases = subs[cases_idx]
per_change = np.round(post_imp[cases_idx],1)

nii_paths = []
seg_nii_paths = []
qsm_dir = '/home/ali/RadDBS-QSM/data/nii/qsm/'
seg_dir = '/home/ali/RadDBS-QSM/data/nii/seg/'
qsm_niis = sorted(os.listdir(qsm_dir))
seg_niis = sorted(os.listdir(seg_dir))
for k in np.arange(len(ccases)):
    for file in qsm_niis:
        if int(ccases[k]) == int(file[18:20]):
            nii_paths.append(qsm_dir+file)
            seg_nii_paths.append(seg_dir+'labels_2iMag'+file[18:20]+'.nii.gz')

train_dir, test_dir, train_seg, test_seg, y_train, y_test = skm.train_test_split(nii_paths, seg_nii_paths, per_change, test_size=0.1, random_state=1)
train_dir, val_dir, train_seg, val_seg, y_train, y_val = skm.train_test_split(train_dir, train_seg, y_train, test_size=0.2, random_state=1)

In [3]:
args = AttrDict()
args.gpu = 0
args.optimizer = 'sgd'
args.lr = 1e-3
args.epoch = 100
args.momentum = 0.9
args.weight_decay = 1e-4
args.schedule = [60,80]
args.print_freq = 10
args.resume = ''
args.pretrained = False
args.evaluate = False
args.loss = 'l1'
args.dataset = 'qsm'
args.model = 'resnet50'
args.store_root = 'checkpoint'
args.store_name = ''
args.data_dir = '/home/ali/RadDBS-QSM/data/qsm/'
args.fds = False
args.fds_kernel = 'gaussian'
args.fds_ks = 3
args.fds_sigma = 1
args.fds_mmt = 0.9
args.start_update = 0
args.start_smooth = 1
args.bucket_num = 5
args.bucket_start = 3
args.start_epoch = 0
args.best_loss = 1e5
args.reweight = 'none'
args.retrain_fc = False
args.lds = False
args.lds_kernel = 'gaussian'
args.lds_ks = 3
args.lds_sigma = 1
args.batch_size = 3
args.workers = 24

In [22]:
import numpy as np
import nibabel as nib
import numpy as np
import torch
import torch.utils.data as data
import torchio as tio
from imbalanced_regression.utils import get_lds_kernel_window
import logging
from scipy.ndimage import convolve1d
from torch.utils import data
import torchio.transforms as transforms
from imbalanced_regression.qsm.utils import pyvis, model_scale, scale_feature_matrix
import sklearn.preprocessing as skp

class QSM(data.Dataset):
    def __init__(self, data_dir, mask_dir, targets, nz, nx, split='train', reweight='none',
                 lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2):
        self.images_list = [nib.load(image_path) for image_path in data_dir]
        self.masks_list = [nib.load(mask_path) for mask_path in mask_dir]
        self.data_dir = data_dir
        self.mask_dir = mask_dir
        self.targets = targets
        self.nz = nz
        self.nx = nx
        self.split = split
        self.weights = self._prepare_weights(reweight=reweight, lds=lds, lds_kernel=lds_kernel, lds_ks=lds_ks, lds_sigma=lds_sigma)

    def __len__(self):
        return len(self.images_list)

    def __getitem__(self, index):
        case_dir = self.data_dir[index]
        #print(case_dir)
        nx = self.nx
        nz = self.nz
        #print('nz is ',nz, ' and nx is ',nx)
        nii_image = self.images_list[index]
        nii_mask = self.masks_list[index]
        data = np.asarray(nii_image.dataobj)
        mask = np.asarray(nii_mask.dataobj)
        #print('Applying mask of shape ',str(mask.shape),' to image of size ',str(data.shape),' for ',case_dir)#,' with size ',str(self.img_size)+' before transform')
        #try:
        img = data[:,:,~(mask==0).all((0,1))]
        img = img[~(mask==0).all((1,2)),:,:]
        img = torch.from_numpy(img[:,~(mask==0).all((0,2)),:])
        #except:
        #    print(case_dir,' mask has shape ',str(mask.shape))
        # self.img_size = img.shape
        # target = self.targets[index]
        # transform = self.get_transform(img,nx,nz)
        # img = torch.squeeze(transform(torch.unsqueeze(img,axis=0)))
        # #print(case_dir+' has size ',str(img.shape)+' after transform')
        # label = target
        # weight = np.asarray([self.weights[index]]).astype('float32') if self.weights is not None else np.asarray([np.float32(1.)])

        return img, #label, weight

    def get_transform(self,img,nx,nz):
        self.img = img
        self.img_size = (self.img).shape
        if self.img_size[2]>nz:
            self.img = self.img[:,:,(self.img_size[2]//2)-(nz//2):(self.img_size[2]//2)+(nz//2)]
            tpad = transforms.Pad((0,0,0))
        else:
            if (nz-self.img_size[2])/2 == (nz-self.img_size[2])//2:
                tpad = transforms.Pad((0,0,(nz-self.img_size[2])//2))
            else:
                #print('Padding an odd number of slices with ',str((nz-self.img_size[2])//2),' and ',str(((nz-self.img_size[2])//2)+1))                      
                tpad = transforms.Pad((0,0,0,0,
                                    (nz-self.img_size[2])//2,
                                    ((nz-self.img_size[2])//2)+1))
        if self.split == 'train':
            transform = transforms.Compose([
                transforms.Crop((nx,nx,nx,nx,0,0)),
                tpad,
                #transforms.RandomFlip(axes=['LR', 'AP', 'IS']),
                transforms.RescaleIntensity(out_min_max=(-1, 1)),
            ])
        else:
            transform = transforms.Compose([
                transforms.Crop((nx,nx,nx,nx,0,0)),
                tpad,
                transforms.RescaleIntensity(out_min_max=(-1, 1)),
            ])
        return transform

    def _prepare_weights(self, reweight, max_target=1, lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2):
        assert reweight in {'none', 'inverse', 'sqrt_inv'}
        assert reweight != 'none' if lds else True, \
            "Set reweight to \'sqrt_inv\' (default) or \'inverse\' when using LDS"

        value_dict = {x: 0 for x in range(max_target)}
        labels = self.targets
        for label in labels:
            value_dict[min(max_target - 1, int(label))] += 1
        if reweight == 'sqrt_inv':
            value_dict = {k: np.sqrt(v) for k, v in value_dict.items()}
        elif reweight == 'inverse':
            value_dict = {k: np.clip(v, 5, 1000) for k, v in value_dict.items()}  # clip weights for inverse re-weight
        num_per_label = [value_dict[min(max_target - 1, int(label))] for label in labels]
        if not len(num_per_label) or reweight == 'none':
            return None
        print(f"Using re-weighting: [{reweight.upper()}]")

        if lds:
            lds_kernel_window = get_lds_kernel_window(lds_kernel, lds_ks, lds_sigma)
            print(f'Using LDS: [{lds_kernel.upper()}] ({lds_ks}/{lds_sigma})')
            smoothed_value = convolve1d(
                np.asarray([v for _, v in value_dict.items()]), weights=lds_kernel_window, mode='constant')
            num_per_label = [smoothed_value[min(max_target - 1, int(label))] for label in labels]

        weights = [np.float32(1 / x) for x in num_per_label]
        scaling = len(weights) / np.sum(weights)
        weights = [scaling * x for x in weights]
        return weights


class QSM_features(data.Dataset):
    def __init__(self, data_dir, targets, pre_metric, scaler_ss=None, split='train', reweight='none',
                 lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2):
        self.feature_matrices = [np.load(feature_matrix) for feature_matrix in data_dir]
        self.data_dir = data_dir
        self.targets = targets
        self.split = split
        self.pre_updrs_off = pre_metric
        self.scaler_ss = scaler_ss
        self.weights = self._prepare_weights(reweight=reweight, lds=lds, lds_kernel=lds_kernel, lds_ks=lds_ks, lds_sigma=lds_sigma)
        if self.split == 'train':
            self.data,self.scaler_ss = model_scale(skp.MinMaxScaler(),
                                                np.asarray(self.feature_matrices[:]).reshape((len(targets),6,1595)),self.pre_updrs_off)
        else:
            self.data = (scale_feature_matrix(np.asarray(self.feature_matrices[:]).reshape((len(targets),6,1595)),pre_metric,self.scaler_ss)).reshape(len(targets),6*1596)
    
    def __len__(self):
        return len(self.feature_matrices)

    def __getitem__(self, index):
        case_dir = self.data_dir[index]
        X0 = self.data[index]  
        img = torch.unsqueeze(torch.unsqueeze(torch.from_numpy(np.float32(X0)),dim=1),dim=2)
        label = self.targets[index]
        weight = np.asarray([self.weights[index]]).astype('float32') if self.weights is not None else np.asarray([np.float32(1.)])
        return img, label, weight

    def __getscaler__(self):
        return(self.scaler_ss)
    
    def _prepare_weights(self, reweight, max_target=1, lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2):
        assert reweight in {'none', 'inverse', 'sqrt_inv'}
        assert reweight != 'none' if lds else True, \
            "Set reweight to \'sqrt_inv\' (default) or \'inverse\' when using LDS"
        value_dict =  {x: 0 for x in np.round(np.linspace(0,max_target,11),1)}
        labels = self.targets
        for label in labels:
            value_dict[min(max_target, label)] += 1
        if reweight == 'sqrt_inv':
            value_dict = {k: np.sqrt(v) for k, v in value_dict.items()}
        elif reweight == 'inverse':
            value_dict = {k: np.clip(v, 5, 1000) for k, v in value_dict.items()}  # clip weights for inverse re-weight
        num_per_label = [value_dict[min(max_target - 1, int(label))] for label in labels]
        if not len(num_per_label) or reweight == 'none':
            return None
        print(f"Using re-weighting: [{reweight.upper()}]")

        if lds:
            lds_kernel_window = get_lds_kernel_window(lds_kernel, lds_ks, lds_sigma)
            print(f'Using LDS: [{lds_kernel.upper()}] ({lds_ks}/{lds_sigma})')
            smoothed_value = convolve1d(
                np.asarray([v for _, v in value_dict.items()]), weights=lds_kernel_window, mode='constant')
            num_per_label = [smoothed_value[min(max_target - 1, int(label))] for label in labels]

        weights = [np.float32(1 / x) for x in num_per_label]
        scaling = len(weights) / np.sum(weights)
        weights = [scaling * x for x in weights]
        return weights

    def __getslabels__(self, max_target=1, lds_kernel='gaussian', lds_ks=5, lds_sigma=2):
        value_dict = {x: 0 for x in np.round(np.linspace(0,max_target,11),1)}
        labels = self.targets
        for label in labels:
            value_dict[min(max_target, label)] += 1
        lds_kernel_window = get_lds_kernel_window(lds_kernel, lds_ks, lds_sigma)
        smoothed_value = convolve1d(
                np.asarray([v for _, v in value_dict.items()]), weights=lds_kernel_window, mode='constant')
        return smoothed_value,  np.asarray([v for _, v in value_dict.items()])

In [23]:
train_dataset = QSM(data_dir=train_dir, mask_dir=train_seg, targets=y_train, nx=128, nz=128, split='train',
                          reweight=args.reweight, lds=args.lds, lds_kernel=args.lds_kernel, lds_ks=args.lds_ks, lds_sigma=args.lds_sigma)
val_dataset = QSM(data_dir=val_dir, mask_dir = val_seg, targets = y_val, nx=128, nz=128, split='val')
test_dataset = QSM(data_dir=test_dir, mask_dir = test_seg, targets = y_test, nx=128, nz=128, split='test')

In [24]:
train_dataset[0][0].shape

torch.Size([123, 123, 96])

: 

In [4]:
def main():
    if args.gpu is not None:
        print(f"Use GPU: {args.gpu} for training")

    # Data
    print('=====> Preparing data...')

    train_dataset = QSM(data_dir=train_dir, mask_dir=train_seg, targets=y_train, nx=128, nz=128, split='train',
                          reweight=args.reweight, lds=args.lds, lds_kernel=args.lds_kernel, lds_ks=args.lds_ks, lds_sigma=args.lds_sigma)
    val_dataset = QSM(data_dir=val_dir, mask_dir = val_seg, targets = y_val, nx=128, nz=128, split='val')
    test_dataset = QSM(data_dir=test_dir, mask_dir = test_seg, targets = y_test, nx=128, nz=128, split='test')

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, pin_memory=True, drop_last=False)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.workers, pin_memory=True, drop_last=False)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.workers, pin_memory=True, drop_last=False)
    print(f"Training data size: {len(train_dataset)}")
    print(f"Validation data size: {len(val_dataset)}")
    print(f"Test data size: {len(test_dataset)}")

    # Model
    print('=====> Building model...')
    model = resnet50(fds=args.fds, bucket_num=args.bucket_num, bucket_start=args.bucket_start,
                     start_update=args.start_update, start_smooth=args.start_smooth,
                     kernel=args.fds_kernel, ks=args.fds_ks, sigma=args.fds_sigma, momentum=args.fds_mmt)
    model = torch.nn.DataParallel(model).cuda()

    # evaluate only
    if args.evaluate:
        assert args.resume, 'Specify a trained model using [args.resume]'
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        print(f"===> Checkpoint '{args.resume}' loaded (epoch [{checkpoint['epoch']}]), testing...")
        validate(test_loader, model, train_labels=y_test, prefix='Test')
        return

    if args.retrain_fc:
        assert args.reweight != 'none' and args.pretrained
        print('===> Retrain last regression layer only!')
        for name, param in model.named_parameters():
            if 'fc' not in name and 'linear' not in name:
                param.requires_grad = False

    # Loss and optimizer
    if not args.retrain_fc:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) if args.optimizer == 'adam' else \
            torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        # optimize only the last linear layer
        parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
        names = list(filter(lambda k: k is not None, [k if v.requires_grad else None for k, v in model.module.named_parameters()]))
        assert 1 <= len(parameters) <= 2  # fc.weight, fc.bias
        print(f'===> Only optimize parameters: {names}')
        optimizer = torch.optim.Adam(parameters, lr=args.lr) if args.optimizer == 'adam' else \
            torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if args.pretrained:
        checkpoint = torch.load(args.pretrained, map_location="cpu")
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in checkpoint['state_dict'].items():
            if 'linear' not in k and 'fc' not in k:
                new_state_dict[k] = v
        model.load_state_dict(new_state_dict, strict=False)
        print(f'===> Pretrained weights found in total: [{len(list(new_state_dict.keys()))}]')
        print(f'===> Pre-trained model loaded: {args.pretrained}')

    if args.resume:
        if os.path.isfile(args.resume):
            print(f"===> Loading checkpoint '{args.resume}'")
            checkpoint = torch.load(args.resume) if args.gpu is None else \
                torch.load(args.resume, map_location=torch.device(f'cuda:{str(args.gpu)}'))
            args.start_epoch = checkpoint['epoch']
            args.best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(f"===> Loaded checkpoint '{args.resume}' (Epoch [{checkpoint['epoch']}])")
        else:
            print(f"===> No checkpoint found at '{args.resume}'")

    globals()[f"weighted_{args.loss}_loss"] = loss.weighted_l1_loss
    cudnn.benchmark = True

    for epoch in range(args.start_epoch, args.epoch):
        adjust_learning_rate(optimizer, epoch, args)
        train_loss = train(train_loader, model, optimizer, epoch)
        val_loss_mse, val_loss_l1, val_loss_gmean = validate(val_loader, model, train_labels=y_train)
        loss_metric = val_loss_mse if args.loss == 'mse' else val_loss_l1
        is_best = loss_metric < args.best_loss
        args.best_loss = min(loss_metric, args.best_loss)
        print(f"Best {'L1' if 'l1' in args.loss else 'MSE'} Loss: {args.best_loss:.3f}")
        save_checkpoint(args, {
            'epoch': epoch + 1,
            'model': args.model,
            'best_loss': args.best_loss,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, is_best)
        print(f"Epoch #{epoch}: Train loss [{train_loss:.4f}]; "
              f"Val loss: MSE [{val_loss_mse:.4f}], L1 [{val_loss_l1:.4f}], G-Mean [{val_loss_gmean:.4f}]")

    # test with best checkpoint
    print("=" * 120)
    print("Test best model on testset...")
    args.store_name = str(str(epoch),'_',str(datetime.now()))
    checkpoint = torch.load(f"{args.store_root}/{args.store_name}/ckpt.best.pth.tar")
    model.load_state_dict(checkpoint['state_dict'])
    print(f"Loaded best model, epoch {checkpoint['epoch']}, best val loss {checkpoint['best_loss']:.4f}")
    test_loss_mse, test_loss_l1, test_loss_gmean = validate(test_loader, model, train_labels=y_test, prefix='Test')
    print(f"Test loss: MSE [{test_loss_mse:.4f}], L1 [{test_loss_l1:.4f}], G-Mean [{test_loss_gmean:.4f}]\nDone")

def train(train_loader, model, optimizer, epoch):
    batch_time = AverageMeter('Time', ':6.2f')
    data_time = AverageMeter('Data', ':6.4f')
    losses = AverageMeter(f'Loss ({args.loss.upper()})', ':.3f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses],
        prefix="Epoch: [{}]".format(epoch)
    )

    model.train()
    end = time.time()
    for idx, (inputs, targets, weights) in enumerate(train_loader):
        data_time.update(time.time() - end)
        inputs, targets, weights = \
            inputs.cuda(non_blocking=True), targets.cuda(non_blocking=True), weights.cuda(non_blocking=True)
        print(inputs.shape)
        print(inputs[:,0,0,0])
        print(targets)
        print(weights)
        if args.fds:
            outputs, _ = model(inputs, targets, epoch)
        else:
            outputs = model(inputs, targets, epoch)

        loss = globals()[f"weighted_{args.loss}_loss"](outputs, torch.unsqueeze(targets,dim=1), weights)
        assert not (np.isnan(loss.item()) or loss.item() > 1e6), f"Loss explosion: {loss.item()}"

        losses.update(loss.item(), inputs.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()
        if idx % args.print_freq == 0:
            progress.display(idx)

    if args.fds and epoch >= args.start_update:
        print(f"Create Epoch [{epoch}] features of all training data...")
        encodings, labels = [], []
        with torch.no_grad():
            for (inputs, targets, _) in tqdm(train_loader):
                inputs = inputs.cuda(non_blocking=True)
                outputs, feature = model(inputs, targets, epoch)
                encodings.extend(feature.data.squeeze().cpu().numpy())
                labels.extend(targets.data.squeeze().cpu().numpy())

        encodings, labels = torch.from_numpy(np.vstack(encodings)).cuda(), torch.from_numpy(np.hstack(labels)).cuda()
        model.module.FDS.update_last_epoch_stats(epoch)
        model.module.FDS.update_running_stats(encodings, labels, epoch)

    return losses.avg


def validate(val_loader, model, train_labels=None, prefix='Val'):
    batch_time = AverageMeter('Time', ':6.3f')
    losses_mse = AverageMeter('Loss (MSE)', ':.3f')
    losses_l1 = AverageMeter('Loss (L1)', ':.3f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses_mse, losses_l1],
        prefix=f'{prefix}: '
    )

    criterion_mse = nn.MSELoss()
    criterion_l1 = nn.L1Loss()
    criterion_gmean = nn.L1Loss(reduction='none')

    model.eval()
    losses_all = []
    preds, labels = [], []
    with torch.no_grad():
        end = time.time()
        for idx, (inputs, targets, _) in enumerate(val_loader):
            inputs, targets = inputs.cuda(non_blocking=True), targets.cuda(non_blocking=True)
            outputs = model(inputs)

            preds.extend(outputs.data.cpu().numpy())
            labels.extend(targets.data.cpu().numpy())
            loss_mse = criterion_mse(outputs, torch.unsqueeze(targets,dim=1))
            loss_l1 = criterion_l1(outputs, torch.unsqueeze(targets,dim=1))
            loss_all = criterion_gmean(outputs, torch.unsqueeze(targets,dim=1))
            losses_all.extend(loss_all.cpu().numpy())

            losses_mse.update(loss_mse.item(), inputs.size(0))
            losses_l1.update(loss_l1.item(), inputs.size(0))

            batch_time.update(time.time() - end)
            end = time.time()
            if idx % args.print_freq == 0:
                progress.display(idx)
        # print('Validate labels: ',str(labels))
        print('Calculating shot metrics for validation predictions of size ',str(len(preds)),' with validations labels of size ',str(len(labels)), ' and training labels of size ', str(len(train_labels)))
        # print('Train labels: ',str(train_labels))
        shot_dict = shot_metrics(np.hstack(preds), np.hstack(labels), train_labels)
        loss_gmean = gmean(np.hstack(losses_all), axis=None).astype(float)
        print(f" * Overall: MSE {losses_mse.avg:.3f}\tL1 {losses_l1.avg:.3f}\tG-Mean {loss_gmean:.3f}")
        print(f" * Many: MSE {shot_dict['many']['mse']:.3f}\t"
              f"L1 {shot_dict['many']['l1']:.3f}\tG-Mean {shot_dict['many']['gmean']:.3f}")
        print(f" * Median: MSE {shot_dict['median']['mse']:.3f}\t"
              f"L1 {shot_dict['median']['l1']:.3f}\tG-Mean {shot_dict['median']['gmean']:.3f}")
        print(f" * Low: MSE {shot_dict['low']['mse']:.3f}\t"
              f"L1 {shot_dict['low']['l1']:.3f}\tG-Mean {shot_dict['low']['gmean']:.3f}")
        print('Predicted ',str(preds),' for true improvements ',str(targets))
    return losses_mse.avg, losses_l1.avg, loss_gmean


def shot_metrics(preds, labels, train_labels, many_shot_thr=5, low_shot_thr=2):
   
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()
    elif isinstance(preds, np.ndarray):
        pass
    else:
        raise TypeError(f'Type ({type(preds)}) of predictions not supported')

    train_class_count, test_class_count = [], []
    mse_per_class, l1_per_class, l1_all_per_class = [], [], []
    #print('Unique validate labels: ',str(np.unique(labels)))
    for l in np.unique(labels):
        # print('Looking for label ',str(l),' in train labels: ',str(train_labels))
        train_class_count.append(len(train_labels[train_labels == l]))
        #print(train_labels)
        #print(train_labels == l)
        # print('Looking for label ',str(l),' in test labels: ',str(labels))
        test_class_count.append(len(labels[labels == l]))
        # print('Train class count is ',str(train_class_count))
        # print('Test class count is ',str(test_class_count))
        mse_per_class.append(np.sum((preds[labels == l] - labels[labels == l]) ** 2))
        l1_per_class.append(np.sum(np.abs(preds[labels == l] - labels[labels == l])))
        l1_all_per_class.append(np.abs(preds[labels == l] - labels[labels == l]))

    many_shot_mse, median_shot_mse, low_shot_mse = [], [], []
    many_shot_l1, median_shot_l1, low_shot_l1 = [], [], []
    many_shot_gmean, median_shot_gmean, low_shot_gmean = [], [], []
    many_shot_cnt, median_shot_cnt, low_shot_cnt = [], [], []

    for i in range(len(train_class_count)):
        if train_class_count[i] > many_shot_thr:
            many_shot_mse.append(mse_per_class[i])
            many_shot_l1.append(l1_per_class[i])
            many_shot_gmean += list(l1_all_per_class[i])
            many_shot_cnt.append(test_class_count[i])
        elif train_class_count[i] < low_shot_thr:
            low_shot_mse.append(mse_per_class[i])
            low_shot_l1.append(l1_per_class[i])
            low_shot_gmean += list(l1_all_per_class[i])
            low_shot_cnt.append(test_class_count[i])
        else:
            median_shot_mse.append(mse_per_class[i])
            median_shot_l1.append(l1_per_class[i])
            median_shot_gmean += list(l1_all_per_class[i])
            median_shot_cnt.append(test_class_count[i])

    shot_dict = defaultdict(dict)
    shot_dict['many']['mse'] = np.sum(many_shot_mse) / np.sum(many_shot_cnt)
    shot_dict['many']['l1'] = np.sum(many_shot_l1) / np.sum(many_shot_cnt)
    shot_dict['many']['gmean'] = gmean(np.hstack(many_shot_gmean), axis=None).astype(float)
    shot_dict['median']['mse'] = np.sum(median_shot_mse) / np.sum(median_shot_cnt)
    shot_dict['median']['l1'] = np.sum(median_shot_l1) / np.sum(median_shot_cnt)
    shot_dict['median']['gmean'] = gmean(np.hstack(median_shot_gmean), axis=None).astype(float)
    shot_dict['low']['mse'] = np.sum(low_shot_mse) / np.sum(low_shot_cnt)
    shot_dict['low']['l1'] = np.sum(low_shot_l1) / np.sum(low_shot_cnt)
    shot_dict['low']['gmean'] = gmean(np.hstack(low_shot_gmean), axis=None).astype(float)

    return shot_dict


if __name__ == '__main__':
    main()



Use GPU: 0 for training
=====> Preparing data...
Training data size: 33
Validation data size: 9
Test data size: 5
=====> Building model...
torch.Size([3, 256, 256, 128])
tensor([-0.1345, -0.0630,  0.3004], device='cuda:0')
tensor([0.7000, 0.7000, 0.8000], device='cuda:0', dtype=torch.float64)
tensor([[1.],
        [1.],
        [1.]], device='cuda:0')


RuntimeError: Given groups=1, weight of size [128, 256, 7, 7, 7], expected input[1, 3, 256, 256, 128] to have 256 channels, but got 3 channels instead