In [1]:
import os,sys
sys.path.insert(0, r'~\trainer')
sys.path.insert(0, r'~\core')
import torch
from core.mean_teacher.run_context import RunContext
from core.mean_teacher.cli import parse_dict_args
from core import mean_teacher_main
import torchvision.transforms as transforms
from core.mean_teacher import data

In [2]:
def data_trans(stats=None, num_classes=None):
    """Build the augmentation pipelines and dataset config passed to ``mean_teacher_main.main``.

    Args:
        stats: mapping with ``mean`` and ``std`` entries, forwarded as keyword
            arguments to ``transforms.Normalize``. When omitted, the
            notebook-level ``channel_stats`` global (set in each task cell) is used.
        num_classes: number of target classes. When omitted, the
            notebook-level ``n_labels`` global is used.

    Returns:
        dict with keys ``train_transformation``, ``eval_transformation``,
        ``datadir`` (always ``''``) and ``num_classes``.
    """
    # Resolve the notebook globals lazily, only when the caller did not pass
    # explicit values, so existing `data_trans()` calls behave as before.
    if stats is None:
        stats = channel_stats
    if num_classes is None:
        num_classes = n_labels

    # Wrapped in data.TransformTwice; presumably this produces two augmented
    # views per sample for the student/teacher pair — see core.mean_teacher.data.
    train_transformation = data.TransformTwice(transforms.Compose([
        data.RandomTranslateWithReflect(4),
        transforms.RandomRotation(10),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(**stats)
    ]))
    # Deterministic evaluation pipeline: resize, then center-crop to the training size.
    eval_transformation = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(**stats)
    ])

    return {
        'train_transformation': train_transformation,
        'eval_transformation': eval_transformation,
        # Empty; the split files are passed as full paths via train_subdir /
        # eval_subdir in `parameters` — presumably joined onto datadir, confirm.
        'datadir': '',
        'num_classes': num_classes
    }

def parameters(batch_size=None, num_labels=None, dataset_name=None, data_root=None,
               data_seeds=range(10, 110, 10)):
    """Yield one keyword-argument dict for ``run`` per data split.

    Every argument left as ``None`` falls back to the notebook global of the
    matching role (``bs``, ``n_labels``, ``ds_name``, ``root``), which the task
    cells set before iterating.

    Args:
        batch_size: per-GPU batch size (``base_batch_size``).
        num_labels: number of classes, forwarded as ``n_labels``.
        dataset_name: name embedded in each run's ``title``.
        data_root: directory holding ``train_<seed>.txt`` and
            ``val_label_<seed>.txt`` split files.
        data_seeds: split identifiers to sweep. The value 100 is special-cased
            as the fully labeled split (presumably 100% labeled data — confirm).

    Yields:
        dict of shared training defaults plus the per-split entries.
    """
    batch_size = bs if batch_size is None else batch_size
    num_labels = n_labels if num_labels is None else num_labels
    dataset_name = ds_name if dataset_name is None else dataset_name
    data_root = root if data_root is None else data_root

    defaults = {
        'workers': 2,
        'checkpoint_epochs': 1,
        'arch': 'resnet_50',
        'consistency_type': 'mse',
        'consistency_rampup': 5,
        'consistency': 4.0,
        'weight_decay': 2e-4,
        'lr_rampup': 0,
        'base_lr': 0.001,
        'nesterov': True,
        'pretrained': True,

        'epochs': 300,
        'lr_rampdown_epochs': 400,
        'ema_decay': 0.99,
    }

    for data_seed in data_seeds:
        # On the fully labeled split, the unlabeled stream is excluded and no
        # dedicated labeled sub-batch is reserved.
        fully_labeled = data_seed == 100
        yield {
            **defaults,
            'title': '%s-label %s' % (data_seed, dataset_name),
            'n_labels': num_labels,
            'data_seed': data_seed,
            'base_batch_size': batch_size,
            'base_labeled_batch_size': 0 if fully_labeled else batch_size // 2,
            # Hyphenated unlike the other keys; presumably parse_dict_args maps
            # keys onto CLI flags either way — confirm before renaming.
            'exclude-unlabeled': fully_labeled,
            'train_subdir': os.path.join(data_root, 'train_%s.txt' % int(data_seed)),
            'eval_subdir': os.path.join(data_root, 'val_label_%s.txt' % int(data_seed)),
        }

def run(title, base_batch_size, base_labeled_batch_size, base_lr, n_labels, data_seed, *, lab=None, **kwargs):
    """Train one mean-teacher model for a single data split.

    Batch sizes and learning rate are scaled linearly by the number of visible
    GPUs, then everything is handed to ``mean_teacher_main`` via its parsed args.

    Args:
        title: run title; forwarded to ``parse_dict_args``.
        base_batch_size: per-GPU total batch size.
        base_labeled_batch_size: per-GPU labeled portion of each batch.
        base_lr: per-GPU learning rate.
        n_labels: class count; only used in the run-context name.
        data_seed: split id selecting ``label_<seed>.txt`` under ``root``.
        lab: tag appended to the run-context name. Defaults to a notebook-level
            ``lab`` if one is defined, else the name of ``root``'s parent
            folder (e.g. ``imgs_boundary_clear``).
        **kwargs: remaining options forwarded unchanged to ``parse_dict_args``.

    Raises:
        AssertionError: if no CUDA device is visible.
    """
    ngpu = torch.cuda.device_count()
    assert ngpu > 0, "Expecting at least one GPU, found none."

    if lab is None:
        lab = globals().get('lab')
    if lab is None:
        # PureWindowsPath splits on both '\\' and '/', matching the Windows-style
        # paths used in the task cells regardless of the host OS.
        from pathlib import PureWindowsPath
        lab = PureWindowsPath(root).parent.name

    adapted_args = {
        'batch_size': base_batch_size * ngpu,
        'labeled_batch_size': base_labeled_batch_size * ngpu,
        'lr': base_lr * ngpu,
        'labels': os.path.join(root,'label_%s.txt'% int(data_seed)),
    }
    # `lab` keeps tasks that share ds_name and n_labels from colliding on one run name.
    context = RunContext(ds_name, "c{}_r{}_s{}_lr{}_{}".format(n_labels, data_seed, base_batch_size, base_lr,lab))
    mean_teacher_main.args = parse_dict_args(**adapted_args, **kwargs)
    mean_teacher_main.main(context, data_trans())

## Clear_boundary

In [2]:
bs = 16
n_labels = 2
ds_name = 'model_path'
lab = 'boundary_clear'  # task tag that `run` appends to the run-context name

channel_stats = dict(mean=[0.2061, 0.1973, 0.1918],std=[0.2777, 0.2747, 0.2733]) # Copied from the output of data_preprocessing

root = r"~\data\imgs_feature_extraction_deep_learning\imgs_boundary_clear\train"

for run_params in parameters():
    run(**run_params)

## Surface_rough

In [3]:
bs = 16
n_labels = 2
ds_name = 'model_path'
lab = 'surface_rough'  # task tag that `run` appends to the run-context name

# Same statistics as the boundary-clear task; NOTE(review): confirm they were computed over this image set too.
channel_stats = dict(mean=[0.2061, 0.1973, 0.1918],std=[0.2777, 0.2747, 0.2733])

root = r"~\data\imgs_feature_extraction_deep_learning\imgs_surface_rough\train"

for run_params in parameters():
    run(**run_params)

## Bleeding

In [4]:
bs = 16
n_labels = 2
ds_name = 'model_path'
lab = 'bleeding'  # task tag that `run` appends to the run-context name

# Same statistics as the boundary-clear task; NOTE(review): confirm they were computed over this image set too.
channel_stats = dict(mean=[0.2061, 0.1973, 0.1918],std=[0.2777, 0.2747, 0.2733])

root = r"~\data\imgs_feature_extraction_deep_learning\imgs_bleeding\train"

for run_params in parameters():
    run(**run_params)

## Tone

In [5]:
bs = 16
n_labels = 3
ds_name = 'model_path'
lab = 'tone'  # task tag that `run` appends to the run-context name

# Same statistics as the boundary-clear task; NOTE(review): confirm they were computed over this image set too.
channel_stats = dict(mean=[0.2061, 0.1973, 0.1918],std=[0.2777, 0.2747, 0.2733])

root = r"~\data\imgs_feature_extraction_deep_learning\imgs_tone\train"

for run_params in parameters():
    run(**run_params)

## Elevated

In [6]:
bs = 16
n_labels = 2
ds_name = 'model_path'
lab = 'elevated'  # task tag that `run` appends to the run-context name

# Same statistics as the boundary-clear task; NOTE(review): confirm they were computed over this image set too.
channel_stats = dict(mean=[0.2061, 0.1973, 0.1918],std=[0.2777, 0.2747, 0.2733])

root = r"~\data\imgs_feature_extraction_deep_learning\imgs_elevated\train"

for run_params in parameters():
    run(**run_params)

## Depressed

In [7]:
bs = 16
n_labels = 2
ds_name = 'model_path'
lab = 'depressed'  # task tag that `run` appends to the run-context name

# Same statistics as the boundary-clear task; NOTE(review): confirm they were computed over this image set too.
channel_stats = dict(mean=[0.2061, 0.1973, 0.1918],std=[0.2777, 0.2747, 0.2733])

root = r"~\data\imgs_feature_extraction_deep_learning\imgs_depressed\train"

for run_params in parameters():
    run(**run_params)