# ===== Main.py =====

# %% In [None]:
import timeit
from collections import OrderedDict

import torch
from torchvision import transforms, datasets

from A1_submission import logistic_regression, tune_hyper_parameter
from sklearn.model_selection import train_test_split

# Share tensors between DataLoader worker processes via the 'file_system'
# strategy instead of the default file-descriptor one; presumably this works
# around "too many open files" errors with many workers — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')


def compute_score(acc, acc_thresh):
    """Map an accuracy onto a 0-100 score by linear interpolation.

    Args:
        acc: Achieved accuracy.
        acc_thresh: ``(low, high)`` pair. Accuracies at or below ``low``
            score 0.0, accuracies at or above ``high`` score 100.0, and
            anything in between is scaled linearly.

    Returns:
        The score as a float.
    """
    low, high = acc_thresh
    if acc <= low:
        return 0.0
    if acc >= high:
        return 100.0
    # Strictly between the thresholds, so ``high - low`` cannot be zero here.
    fraction = float(acc - low) / (high - low)
    return fraction * 100

# ===== Submit.py =====

# %% In [None]:
class MyIterableDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that yields the integers in ``[start, end)``.

    The original base class ``MNIST_dataset`` is not defined anywhere in this
    file. The ``start``/``end`` + ``__iter__`` + ``worker_init_fn`` pattern is
    torch's ``IterableDataset`` protocol, so that is the base used here.
    ``worker_init_fn`` narrows ``start``/``end`` per DataLoader worker so
    each worker iterates its own slice.
    """

    def __init__(self, start, end):
        # Zero-argument super() initialises *this* instance. The former
        # ``super(MyIterableDataset).__init__()`` only re-initialised an
        # unbound super proxy and never reached the base class.
        super().__init__()
        # Explicit check instead of ``assert``: asserts vanish under ``-O``.
        if end <= start:
            raise ValueError(
                f"end ({end}) must be greater than start ({start})")
        self.start = start
        self.end = end

    def __iter__(self):
        return iter(range(self.start, self.end))
def worker_init_fn(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # the dataset copy in this worker process
    overall_start = dataset.start
    overall_end = dataset.end
    # configure the dataset to only process the split workload
    per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
    worker_id = worker_info.id
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)



# %% In [None]:
import torchvision
from torch.utils.data import random_split

# NOTE(review): batch_size_train / batch_size_test were referenced here but
# never defined in the visible cells. Reuse an earlier cell's values when
# present; otherwise fall back to these defaults (adjust as needed).
batch_size_train = globals().get("batch_size_train", 64)
batch_size_test = globals().get("batch_size_test", 1000)

# One preprocessing pipeline shared by both splits. The Normalize constants
# are the same ones test() applies to MNIST.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Download into ./data (the root test() also uses) instead of /MNIST_dataset/,
# which lives at the filesystem root and is usually not writable.
MNIST_training = torchvision.datasets.MNIST(
    './data', train=True, download=True, transform=mnist_transform)
MNIST_test_set = torchvision.datasets.MNIST(
    './data', train=False, download=True, transform=mnist_transform)

# Create a training and a validation set: hold out 12,000 training images.
MNIST_training_set, MNIST_validation_set = random_split(
    MNIST_training, [len(MNIST_training) - 12000, 12000])

# Only training benefits from shuffling; accuracy/loss on validation and test
# do not depend on iteration order, so those loaders use a fixed order.
train_loader = torch.utils.data.DataLoader(
    MNIST_training_set, batch_size=batch_size_train, shuffle=True)
validation_loader = torch.utils.data.DataLoader(
    MNIST_validation_set, batch_size=batch_size_train, shuffle=False)
test_loader = torch.utils.data.DataLoader(
    MNIST_test_set, batch_size=batch_size_test, shuffle=False)

# %% In [None]:

def logistic_regression(dataset_name, device):
    """Train a logistic-regression classifier on ``dataset_name``.

    Placeholder: no training is implemented yet, so the returned dict always
    carries ``model=None``; ``dataset_name`` and ``device`` are unused.

    Returns:
        A dict with a single ``"model"`` entry.
    """
    # TODO: implement logistic regression here
    return {"model": None}


def tune_hyper_parameter(dataset_name, target_metric, device):
    """Search logistic-regression hyper-parameters for ``dataset_name``.

    Placeholder: no search is performed yet, so both the best parameters and
    the best ``target_metric`` value come back as ``None``.

    Returns:
        ``(best_params, best_metric)`` tuple.
    """
    # TODO: implement logistic regression hyper-parameter tuning here
    return None, None


# %% In [None]:
def test(
        model,
        dataset_name,
        device,
        batch_size=1,
):
    """Return ``model``'s classification accuracy on a dataset's test split.

    Args:
        model: Module supporting ``.eval()`` that maps a batch of images to
            per-class scores; the prediction is the argmax over dim 1.
        dataset_name: ``"MNIST"`` or ``"CIFAR10"``.
        device: Device each batch is moved to before the forward pass.
        batch_size: Images per forward pass. Defaults to 1 (the previous
            hard-coded value, safe for models that assume a single image);
            larger values evaluate much faster for batch-agnostic models.

    Returns:
        Fraction of test images classified correctly, as a float.

    Raises:
        AssertionError: If ``dataset_name`` is not supported.
    """
    test_dataset = _load_test_dataset(dataset_name)
    # Accuracy is order-independent, so shuffling the test set is pointless.
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False)

    model.eval()
    num_correct = 0
    total = 0
    # Disable autograd once for the whole pass rather than per batch.
    with torch.no_grad():
        for data, targets in test_loader:
            data = data.to(device)
            targets = targets.to(device)
            predicted = torch.argmax(model(data), dim=1)
            total += targets.size(0)
            num_correct += (predicted == targets).sum().item()

    return float(num_correct) / total


def _load_test_dataset(dataset_name):
    """Build the normalised test split of ``dataset_name`` under ./data."""
    if dataset_name == "MNIST":
        return datasets.MNIST(
            root='./data',
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))]))

    if dataset_name == "CIFAR10":
        return datasets.CIFAR10(
            root='./data',
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]))

    raise AssertionError(f"Invalid dataset: {dataset_name}")


# %% In [None]:
# Run configuration consumed by main(). The class attributes are the defaults;
# when the optional paramparse package is installed, main() passes an instance
# to paramparse.process(), presumably so they can be overridden from the
# command line (confirm against paramparse's documentation).
class Args:
    """
    command-line arguments
    """

    """
    'MNIST': run on MNIST dataset (part 1)
    'CIFAR10': run on CIFAR10 dataset (part 2)
    """
    dataset = "MNIST"  # also keys the score thresholds (acc_thresh) in main()
    # dataset = "CIFAR10"

    """
    'logistic': run logistic regression on the specified dataset (parts 1 and 2)
    'tune': run hyper parameter tuning (part 3)
    """
    mode = 'logistic'  # any other value makes main() raise AssertionError
    # mode = 'tune'

    """
    metric with respect to which hyper parameters are to be tuned
    'acc': validation classification accuracy
    'loss': validation loss
    """
    target_metric = 'acc'  # only forwarded to tune_hyper_parameter() in 'tune' mode
    # target_metric = 'loss'

    """
    set to 0 to run on cpu
    """
    gpu = 1  # truthy -> main() picks CUDA if torch.cuda.is_available()

# %% In [None]:
def main():
    """Run the mode selected in ``Args`` and print the results.

    'logistic': train via logistic_regression(), evaluate with test() and
    print accuracy, score and training run time.
    'tune': run tune_hyper_parameter() and print the best metric, the best
    parameters and the run time.

    Raises:
        AssertionError: If ``args.mode`` is neither 'logistic' nor 'tune'.
    """
    args = Args()
    # Keep the try body to the import only, so failures inside
    # paramparse.process() are not silently swallowed.
    try:
        import paramparse
    except ImportError:
        pass  # paramparse is optional; fall back to the defaults in Args.
    else:
        paramparse.process(args)

    device = torch.device("cuda" if args.gpu and torch.cuda.is_available() else "cpu")

    # (min, max) accuracy per dataset, mapped to a 0-100 score by compute_score().
    acc_thresh = dict(
        MNIST=(0.84, 0.94),
        CIFAR10=(0.30, 0.40),
    )

    if args.mode == 'logistic':
        start = timeit.default_timer()
        results = logistic_regression(args.dataset, device)
        model = results['model']

        if model is None:
            print('model is None')
            return

        # Only training is timed; evaluation below is excluded.
        stop = timeit.default_timer()
        run_time = stop - start

        accuracy = test(
            model,
            args.dataset,
            device,
        )

        score = compute_score(accuracy, acc_thresh[args.dataset])
        result = OrderedDict(
            accuracy=accuracy,
            score=score,
            run_time=run_time
        )
        print(f"result on {args.dataset}:")
        for key, value in result.items():
            print(f"\t{key}: {value}")
    elif args.mode == 'tune':
        start = timeit.default_timer()
        best_params, best_metric = tune_hyper_parameter(
            args.dataset, args.target_metric, device)
        stop = timeit.default_timer()
        run_time = stop - start

        # Same guard as the logistic branch's model check: formatting None
        # with ':.4f' below would raise TypeError.
        if best_metric is None:
            print('best_metric is None')
            return

        print()
        print(f"Best {args.target_metric}: {best_metric:.4f}")
        print(f"Best params:\n{best_params}")
        print(f"runtime of tune_hyper_parameter: {run_time}")
    else:
        raise AssertionError(f'invalid mode: {args.mode}')


if __name__ == "__main__":
    main()
