algorithm

In [18]:
# cd JoCoR/

In [24]:
# cd JoCoR/

/home/subrat/JoCoR-env/JoCoR


In [20]:
#!pip install torch
# !pip install torchvision

In [21]:
# !pip install numpy
# !pip install matplotlib

In [87]:
# -*- coding:utf-8 -*-
import torch
import torch.nn.functional as F
from torch.autograd import Variable

from JoCoR.model.cnn import MLPNet,CNN
import numpy as np
from JoCoR.common.utils import accuracy

from JoCoR.algorithm.loss import loss_jocor


class JoCoR:
    def __init__(self, args, train_dataset, device, input_channel, num_classes):

        # Hyper Parameters
        self.batch_size = 128
        learning_rate = args.lr

#         if args.forget_rate is None:
#             if args.noise_type == "asymmetric":
#                 forget_rate = args.noise_rate / 2
#             else:
#                 forget_rate = args.noise_rate
#         else:
#             forget_rate = args.forget_rate

#         self.noise_or_not = train_dataset.noise_or_not

        # Adjust learning rate and betas for Adam Optimizer
        mom1 = 0.9
        mom2 = 0.1
        self.alpha_plan = [learning_rate] * args.n_epoch
        self.beta1_plan = [mom1] * args.n_epoch

        for i in range(args.epoch_decay_start, args.n_epoch):
            self.alpha_plan[i] = float(args.n_epoch - i) / (args.n_epoch - args.epoch_decay_start) * learning_rate
            self.beta1_plan[i] = mom2

        # define drop rate schedule
        self.rate_schedule = np.ones(args.n_epoch) * forget_rate
        self.rate_schedule[:args.num_gradual] = np.linspace(0, forget_rate ** args.exponent, args.num_gradual)

        self.device = device
        self.num_iter_per_epoch = args.num_iter_per_epoch
        self.print_freq = args.print_freq
        self.co_lambda = args.co_lambda
        self.n_epoch = args.n_epoch
        self.train_dataset = train_dataset

        if args.model_type == "googlenet":
            self.model1 = torchvision.models.googlenet(pretrained=True)
            self.model2 = torchvision.models.googlenet(pretrained=True)
        elif args.model_type == "mlp":
            self.model1 = MLPnet()
            self.model2 = MLPnet()

        self.model1.to(device)
        print(self.model1.parameters)

        self.model2.to(device)
        print(self.model2.parameters)

        self.optimizer = torch.optim.Adam(list(self.model1.parameters()) + list(self.model2.parameters()),
                                          lr=learning_rate)

        self.loss_fn = loss_jocor


        self.adjust_lr = args.adjust_lr

    # Evaluate the Model
    def evaluate(self, test_loader):
        print('Evaluating ...')
        self.model1.eval()  # Change model to 'eval' mode.
        self.model2.eval()  # Change model to 'eval' mode

        correct1 = 0
        total1 = 0
        for images, labels in test_loader:
            images = Variable(images).to(self.device)
            logits1 = self.model1(images)
            outputs1 = F.softmax(logits1, dim=1)
            _, pred1 = torch.max(outputs1.data, 1)
            total1 += labels.size(0)
            correct1 += (pred1.cpu() == labels).sum()

        correct2 = 0
        total2 = 0
        for images, labels, _ in test_loader:
            images = Variable(images).to(self.device)
            logits2 = self.model2(images)
            outputs2 = F.softmax(logits2, dim=1)
            _, pred2 = torch.max(outputs2.data, 1)
            total2 += labels.size(0)
            correct2 += (pred2.cpu() == labels).sum()

        acc1 = 100 * float(correct1) / float(total1)
        acc2 = 100 * float(correct2) / float(total2)
        return acc1, acc2

    # Train the Model
    def train(self, train_loader, epoch):
        print('Training ...')
        self.model1.train()  # Change model to 'train' mode.
        self.model2.train()  # Change model to 'train' mode

        if self.adjust_lr == 1:
            self.adjust_learning_rate(self.optimizer, epoch)

        train_total = 0
        train_correct = 0
        train_total2 = 0
        train_correct2 = 0
        pure_ratio_1_list = []
        pure_ratio_2_list = []

        for i, (images, labels, indexes) in enumerate(train_loader):
            ind = indexes.cpu().numpy().transpose()
            if i > self.num_iter_per_epoch:
                break

            images = Variable(images).to(self.device)
            labels = Variable(labels).to(self.device)

            # Forward + Backward + Optimize
            logits1 = self.model1(images)
            prec1 = accuracy(logits1, labels, topk=(1,))
            train_total += 1
            train_correct += prec1

            logits2 = self.model2(images)
            prec2 = accuracy(logits2, labels, topk=(1,))
            train_total2 += 1
            train_correct2 += prec2

            loss_1, loss_2, pure_ratio_1, pure_ratio_2 = self.loss_fn(logits1, logits2, labels, self.rate_schedule[epoch],
                                                                 ind, self.noise_or_not, self.co_lambda)

            self.optimizer.zero_grad()
            loss_1.backward()
            self.optimizer.step()

            pure_ratio_1_list.append(100 * pure_ratio_1)
            pure_ratio_2_list.append(100 * pure_ratio_2)

            if (i + 1) % self.print_freq == 0:
                print(
                    'Epoch [%d/%d], Iter [%d/%d] Training Accuracy1: %.4F, Training Accuracy2: %.4f, Loss1: %.4f, Loss2: %.4f, Pure Ratio1 %.4f %% Pure Ratio2 %.4f %%'
                    % (epoch + 1, self.n_epoch, i + 1, len(self.train_dataset) // self.batch_size, prec1, prec2,
                       loss_1.data.item(), loss_2.data.item(), sum(pure_ratio_1_list) / len(pure_ratio_1_list), sum(pure_ratio_2_list) / len(pure_ratio_2_list)))

        train_acc1 = float(train_correct) / float(train_total)
        train_acc2 = float(train_correct2) / float(train_total2)
        return train_acc1, train_acc2, pure_ratio_1_list, pure_ratio_2_list

    def adjust_learning_rate(self, optimizer, epoch):
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.alpha_plan[epoch]
            param_group['betas'] = (self.beta1_plan[epoch], 0.999)  # Only change beta1

loss.py

In [88]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F

def kl_loss_compute(pred, soft_targets, reduce=True):

    kl = F.kl_div(F.log_softmax(pred, dim=1),F.softmax(soft_targets, dim=1),reduce=False)

    if reduce:
        return torch.mean(torch.sum(kl, dim=1))
    else:
        return torch.sum(kl, 1)




def loss_jocor(y_1, y_2, t, forget_rate, ind, noise_or_not, co_lambda=0.1):

    loss_pick_1 = F.cross_entropy(y_1, t, reduce = False) * (1-co_lambda)
    loss_pick_2 = F.cross_entropy(y_2, t, reduce = False) * (1-co_lambda)
    loss_pick = (loss_pick_1 + loss_pick_2 + co_lambda * kl_loss_compute(y_1, y_2,reduce=False) + co_lambda * kl_loss_compute(y_2, y_1, reduce=False)).cpu()


    ind_sorted = np.argsort(loss_pick.data)
    loss_sorted = loss_pick[ind_sorted]

    remember_rate = 1 - forget_rate
    num_remember = int(remember_rate * len(loss_sorted))

    pure_ratio = np.sum(noise_or_not[ind[ind_sorted[:num_remember]]])/float(num_remember)

    ind_update=ind_sorted[:num_remember]

    # exchange
    loss = torch.mean(loss_pick[ind_update])

    return loss, loss, pure_ratio, pure_ratio



common

In [89]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F



def plot_result(accuracy_list,pure_ratio_list,name="test.png"):
    plt.figure(figsize=(16, 6))
    plt.subplot(1, 2, 1)
    plt.plot(accuracy_list, label='test_accuracy')
    plt.subplot(1, 2, 2)
    plt.plot(pure_ratio_list, label='test_pure_ratio')
    plt.savefig(name)
    plt.show()


def accuracy(logit, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    output = F.softmax(logit, dim=1)
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res[0]

Data

cifar.py

In [90]:
#!pip uninstall noisify

In [91]:
# import os
# import os.path
# import copy
# import hashlib
# import errno
# import numpy as np
# from numpy.testing import assert_array_almost_equal

# # def check_integrity(fpath, md5):
# #     if not os.path.isfile(fpath):
# #         return False
# #     md5o = hashlib.md5()
# #     with open(fpath, 'rb') as f:
# #         # read in 1MB chunks
# #         for chunk in iter(lambda: f.read(1024 * 1024), b''):
# #             md5o.update(chunk)
# #     md5c = md5o.hexdigest()
# #     if md5c != md5:
# #         return False
# #     return True


# # def download_url(url, root, filename, md5):
# #     from six.moves import urllib

# #     root = os.path.expanduser(root)
# #     fpath = os.path.join(root, filename)

# #     try:
# #         os.makedirs(root)
# #     except OSError as e:
# #         if e.errno == errno.EEXIST:
# #             pass
# #         else:
# #             raise

# #     # downloads file
# #     if os.path.isfile(fpath) and check_integrity(fpath, md5):
# #         print('Using downloaded and verified file: ' + fpath)
# #     else:
# #         try:
# #             print('Downloading ' + url + ' to ' + fpath)
# #             urllib.request.urlretrieve(url, fpath)
# #         except:
# #             if url[:5] == 'https':
# #                 url = url.replace('https:', 'http:')
# #                 print('Failed download. Trying https -> http instead.'
# #                       ' Downloading ' + url + ' to ' + fpath)
# #                 urllib.request.urlretrieve(url, fpath)


# def list_dir(root, prefix=False):
#     """List all directories at a given root

#     Args:
#         root (str): Path to directory whose folders need to be listed
#         prefix (bool, optional): If true, prepends the path to each result, otherwise
#             only returns the name of the directories found
#     """
#     root = os.path.expanduser(root)
#     directories = list(
#         filter(
#             lambda p: os.path.isdir(os.path.join(root, p)),
#             os.listdir(root)
#         )
#     )

#     if prefix is True:
#         directories = [os.path.join(root, d) for d in directories]

#     return directories


# def list_files(root, suffix, prefix=False):
#     """List all files ending with a suffix at a given root

#     Args:
#         root (str): Path to directory whose folders need to be listed
#         suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
#             It uses the Python "str.endswith" method and is passed directly
#         prefix (bool, optional): If true, prepends the path to each result, otherwise
#             only returns the name of the files found
#     """
#     root = os.path.expanduser(root)
#     files = list(
#         filter(
#             lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
#             os.listdir(root)
#         )
#     )

#     if prefix is True:
#         files = [os.path.join(root, d) for d in files]

#     return files


# def build_for_cifar100(size, noise):
#     """ The noise matrix flips to the "next" class with probability 'noise'.
#     """

#     assert(noise >= 0.) and (noise <= 1.)

#     P = (1. - noise) * np.eye(size)
#     for i in np.arange(size - 1):
#         P[i, i+1] = noise

#     # adjust last row
#     P[size-1, 0] = noise

#     assert_array_almost_equal(P.sum(axis=1), 1, 1)
#     return P


# # basic function
# def multiclass_noisify(y, P, random_state=0):
#     """ Flip classes according to transition probability matrix T.
#     It expects a number between 0 and the number of classes - 1.
#     """
#     print(np.max(y), P.shape[0])
#     assert P.shape[0] == P.shape[1]
#     assert np.max(y) < P.shape[0]

#     # row stochastic matrix
#     assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
#     assert (P >= 0.0).all()

#     m = y.shape[0]
#     print(m)
#     new_y = y.copy()
#     flipper = np.random.RandomState(random_state)

#     for idx in np.arange(m):
#         i = y[idx]
#         # draw a vector with only an 1
#         flipped = flipper.multinomial(1, P[i, :][0], 1)[0]
#         new_y[idx] = np.where(flipped == 1)[0]

#     return new_y


# # noisify_pairflip call the function "multiclass_noisify"
# def noisify_pairflip(y_train, noise, random_state=None, nb_classes=10):
#     """mistakes:
#         flip in the pair
#     """
#     P = np.eye(nb_classes)
#     n = noise

#     if n > 0.0:
#         # 0 -> 1
#         P[0, 0], P[0, 1] = 1. - n, n
#         for i in range(1, nb_classes-1):
#             P[i, i], P[i, i + 1] = 1. - n, n
#         P[nb_classes-1, nb_classes-1], P[nb_classes-1, 0] = 1. - n, n

#         y_train_noisy = multiclass_noisify(y_train, P=P,
#                                            random_state=random_state)
#         actual_noise = (y_train_noisy != y_train).mean()
#         assert actual_noise > 0.0
#         print('Actual noise %.2f' % actual_noise)
#         y_train = y_train_noisy
#     print(P)

#     return y_train, actual_noise

# def noisify_multiclass_symmetric(y_train, noise, random_state=None, nb_classes=10):
#     """mistakes:
#         flip in the symmetric way
#     """
#     P = np.ones((nb_classes, nb_classes))
#     n = noise
#     P = (n / (nb_classes - 1)) * P

#     if n > 0.0:
#         # 0 -> 1
#         P[0, 0] = 1. - n
#         for i in range(1, nb_classes-1):
#             P[i, i] = 1. - n
#         P[nb_classes-1, nb_classes-1] = 1. - n

#         y_train_noisy = multiclass_noisify(y_train, P=P,
#                                            random_state=random_state)
#         actual_noise = (y_train_noisy != y_train).mean()
#         assert actual_noise > 0.0
#         print('Actual noise %.2f' % actual_noise)
#         y_train = y_train_noisy
#     print(P)

#     return y_train, actual_noise

# def noisify_mnist_asymmetric(y_train, noise, random_state=None):
#     """mistakes:
#         1 <- 7
#         2 -> 7
#         3 -> 8
#         5 <-> 6
#     """
#     nb_classes = 10
#     P = np.eye(nb_classes)
#     n = noise

#     if n > 0.0:
#         # 1 <- 7
#         P[7, 7], P[7, 1] = 1. - n, n

#         # 2 -> 7
#         P[2, 2], P[2, 7] = 1. - n, n

#         # 5 <-> 6
#         P[5, 5], P[5, 6] = 1. - n, n
#         P[6, 6], P[6, 5] = 1. - n, n

#         # 3 -> 8
#         P[3, 3], P[3, 8] = 1. - n, n

#         y_train_noisy = multiclass_noisify(y_train, P=P,
#                                            random_state=random_state)
#         actual_noise = (y_train_noisy != y_train).mean()
#         assert actual_noise > 0.0
#         print('Actual noise %.2f' % actual_noise)
#         y_train = y_train_noisy

#     print(P)

#     return y_train, P


# # def noisify_cifar10_asymmetric(y_train, noise, random_state=None):
# #     """mistakes:
# #         automobile <- truck
# #         bird -> airplane
# #         cat <-> dog
# #         deer -> horse
# #     """
# #     nb_classes = 10
# #     P = np.eye(nb_classes)
# #     n = noise

# #     if n > 0.0:
# #         # automobile <- truck
# #         P[9, 9], P[9, 1] = 1. - n, n

# #         # bird -> airplane
# #         P[2, 2], P[2, 0] = 1. - n, n

# #         # cat <-> dog
# #         P[3, 3], P[3, 5] = 1. - n, n
# #         P[5, 5], P[5, 3] = 1. - n, n

# #         # automobile -> truck
# #         P[4, 4], P[4, 7] = 1. - n, n

# #         y_train_noisy = multiclass_noisify(y_train, P=P,
# #                                            random_state=random_state)
# #         actual_noise = (y_train_noisy != y_train).mean()
# #         assert actual_noise > 0.0
# #         print('Actual noise %.2f' % actual_noise)
# #         y_train = y_train_noisy

# #     return y_train, P


# # def noisify_cifar100_asymmetric(y_train, noise, random_state=None):
# #     """mistakes are inside the same superclass of 10 classes, e.g. 'fish'
# #     """
# #     nb_classes = 100
# #     P = np.eye(nb_classes)
# #     n = noise
# #     nb_superclasses = 20
# #     nb_subclasses = 5

# #     if n > 0.0:
# #         for i in np.arange(nb_superclasses):
# #             init, end = i * nb_subclasses, (i+1) * nb_subclasses
# #             P[init:end, init:end] = build_for_cifar100(nb_subclasses, n)

# #         y_train_noisy = multiclass_noisify(y_train, P=P,
# #                                            random_state=random_state)
# #         actual_noise = (y_train_noisy != y_train).mean()
# #         assert actual_noise > 0.0
# #         print('Actual noise %.2f' % actual_noise)
# #         y_train = y_train_noisy

# #     return y_train, P


def noisify(dataset='mnist', nb_classes=10, train_labels=None, noise_type=None, noise_rate=0, random_state=0):
    if noise_type == 'pairflip':
        train_noisy_labels, actual_noise_rate = noisify_pairflip(train_labels, noise_rate, random_state=random_state, nb_classes=nb_classes)
    if noise_type == 'symmetric':
        train_noisy_labels, actual_noise_rate = noisify_multiclass_symmetric(train_labels, noise_rate, random_state=random_state, nb_classes=nb_classes)
# #     if noise_type == 'asymmetric':
# #         if dataset == 'mnist':
# #             train_noisy_labels, actual_noise_rate = noisify_mnist_asymmetric(train_labels, noise_rate, random_state=random_state)
# #         elif dataset == 'cifar10':
# #             train_noisy_labels, actual_noise_rate = noisify_cifar10_asymmetric(train_labels, noise_rate, random_state=random_state)
# #         elif dataset == 'cifar100':
# #             train_noisy_labels, actual_noise_rate = noisify_cifar100_asymmetric(train_labels, noise_rate, random_state=random_state)
#     return train_noisy_labels, actual_noise_rate

In [92]:
pwd

'/home/subrat'

In [93]:
# from __future__ import print_function
# from PIL import Image
# import os
# import os.path
# import numpy as np
# import sys

# if sys.version_info[0] == 2:
#     import cPickle as pickle
# else:
#     import pickle

# import torch.utils.data as data
# from utils import noisify

# class CIFAR10(data.Dataset):
#     """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

#     Args:
#         root (string): Root directory of dataset where directory
#             ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
#         train (bool, optional): If True, creates dataset from training set, otherwise
#             creates from test set.
#         transform (callable, optional): A function/transform that  takes in an PIL image
#             and returns a transformed version. E.g, ``transforms.RandomCrop``
#         target_transform (callable, optional): A function/transform that takes in the
#             target and transforms it.
#         download (bool, optional): If true, downloads the dataset from the internet and
#             puts it in root directory. If dataset is already downloaded, it is not
#             downloaded again.

#     """
# #     base_folder = 'cifar-10-batches-py'
# #     url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
# #     filename = "cifar-10-python.tar.gz"
# #     tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
# #     train_list = [
# #         ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
# #         ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
# #         ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
# #         ['data_batch_4', '634d18415352ddfa80567beed471001a'],
# #         ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
# #     ]

# #     test_list = [
# #         ['test_batch', '40351d587109b95175f43aff81a1287e'],
# #     ]

#     def __init__(self, root, train=True,
#                  transform=None, target_transform=None,
#                  download=False,
#                  noise_type=None, noise_rate=0.2, random_state=0):
#         self.root = os.path.expanduser(root)
#         self.transform = transform
#         self.target_transform = target_transform
#         self.train = train  # training set or test set
#         self.dataset='cifar10'
#         self.noise_type=noise_type
#         self.nb_classes=10

# #         if download:
# #             self.download()

# #         if not self._check_integrity():
# #             raise RuntimeError('Dataset not found or corrupted.' +
# #                                ' You can use download=True to download it')

#         # now load the picked numpy arrays
#         if self.train:
#             self.train_data = []
#             self.train_labels = []
#             for fentry in self.train_list:
#                 f = fentry[0]
#                 file = os.path.join(self.root, self.base_folder, f)
#                 fo = open(file, 'rb')
#                 if sys.version_info[0] == 2:
#                     entry = pickle.load(fo)
#                 else:
#                     entry = pickle.load(fo, encoding='latin1')
#                 self.train_data.append(entry['data'])
#                 if 'labels' in entry:
#                     self.train_labels += entry['labels']
#                 else:
#                     self.train_labels += entry['fine_labels']
#                 fo.close()

#             self.train_data = np.concatenate(self.train_data)
#             self.train_data = self.train_data.reshape((50000, 3, 32, 32))
#             self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
#             #if noise_type is not None:
#             if noise_type !='clean':
#                 # noisify train data
#                 self.train_labels=np.asarray([[self.train_labels[i]] for i in range(len(self.train_labels))])
#                 self.train_noisy_labels, self.actual_noise_rate = noisify(dataset=self.dataset, train_labels=self.train_labels, noise_type=noise_type, noise_rate=noise_rate, random_state=random_state, nb_classes=self.nb_classes)
#                 self.train_noisy_labels=[i[0] for i in self.train_noisy_labels]
#                 _train_labels=[i[0] for i in self.train_labels]
#                 self.noise_or_not = np.transpose(self.train_noisy_labels)==np.transpose(_train_labels)
#         else:
#             f = self.test_list[0][0]
#             file = os.path.join(self.root, self.base_folder, f)
#             fo = open(file, 'rb')
#             if sys.version_info[0] == 2:
#                 entry = pickle.load(fo)
#             else:
#                 entry = pickle.load(fo, encoding='latin1')
#             self.test_data = entry['data']
#             if 'labels' in entry:
#                 self.test_labels = entry['labels']
#             else:
#                 self.test_labels = entry['fine_labels']
#             fo.close()
#             self.test_data = self.test_data.reshape((10000, 3, 32, 32))
#             self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC

#     def __getitem__(self, index):
#         """
#         Args:
#             index (int): Index

#         Returns:
#             tuple: (image, target) where target is index of the target class.
#         """
#         if self.train:
#             if self.noise_type !='clean':
#                 img, target = self.train_data[index], self.train_noisy_labels[index]
#             else:
#                 img, target = self.train_data[index], self.train_labels[index]
#         else:
#             img, target = self.test_data[index], self.test_labels[index]

#         # doing this so that it is consistent with all other datasets
#         # to return a PIL Image
#         img = Image.fromarray(img)

#         if self.transform is not None:
#             img = self.transform(img)

#         if self.target_transform is not None:
#             target = self.target_transform(target)

#         return img, target, index

#     def __len__(self):
#         if self.train:
#             return len(self.train_data)
#         else:
#             return len(self.test_data)

#     def _check_integrity(self):
#         root = self.root
#         for fentry in (self.train_list + self.test_list):
#             filename, md5 = fentry[0], fentry[1]
#             fpath = os.path.join(root, self.base_folder, filename)
#             if not check_integrity(fpath, md5):
#                 return False
#         return True

#     def download(self):
#         import tarfile

#         if self._check_integrity():
#             print('Files already downloaded and verified')
#             return

#         root = self.root
#         download_url(self.url, root, self.filename, self.tgz_md5)

#         # extract file
#         cwd = os.getcwd()
#         tar = tarfile.open(os.path.join(root, self.filename), "r:gz")
#         os.chdir(root)
#         tar.extractall()
#         tar.close()
#         os.chdir(cwd)

#     def __repr__(self):
#         fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
#         fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
#         tmp = 'train' if self.train is True else 'test'
#         fmt_str += '    Split: {}\n'.format(tmp)
#         fmt_str += '    Root Location: {}\n'.format(self.root)
#         tmp = '    Transforms (if any): '
#         fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
#         tmp = '    Target Transforms (if any): '
#         fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
#         return fmt_str

# class CIFAR100(data.Dataset):
#     """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

#     Args:
#         root (string): Root directory of dataset where directory
#             ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
#         train (bool, optional): If True, creates dataset from training set, otherwise
#             creates from test set.
#         transform (callable, optional): A function/transform that  takes in an PIL image
#             and returns a transformed version. E.g, ``transforms.RandomCrop``
#         target_transform (callable, optional): A function/transform that takes in the
#             target and transforms it.
#         download (bool, optional): If true, downloads the dataset from the internet and
#             puts it in root directory. If dataset is already downloaded, it is not
#             downloaded again.

#     """
#     base_folder = 'cifar-100-python'
#     url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
#     filename = "cifar-100-python.tar.gz"
#     tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
#     train_list = [
#         ['train', '16019d7e3df5f24257cddd939b257f8d'],
#     ]

#     test_list = [
#         ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
#     ]
 

#     def __init__(self, root, train=True,
#                  transform=None, target_transform=None,
#                  download=False,
#                  noise_type=None, noise_rate=0.2, random_state=0):
#         self.root = os.path.expanduser(root)
#         self.transform = transform
#         self.target_transform = target_transform
#         self.train = train  # training set or test set
#         self.dataset='cifar100'
#         self.noise_type=noise_type
#         self.nb_classes=100

#         if download:
#             self.download()

#         if not self._check_integrity():
#             raise RuntimeError('Dataset not found or corrupted.' +
#                                ' You can use download=True to download it')

#         # now load the picked numpy arrays
#         if self.train:
#             self.train_data = []
#             self.train_labels = []
#             for fentry in self.train_list:
#                 f = fentry[0]
#                 file = os.path.join(self.root, self.base_folder, f)
#                 fo = open(file, 'rb')
#                 if sys.version_info[0] == 2:
#                     entry = pickle.load(fo)
#                 else:
#                     entry = pickle.load(fo, encoding='latin1')
#                 self.train_data.append(entry['data'])
#                 if 'labels' in entry:
#                     self.train_labels += entry['labels']
#                 else:
#                     self.train_labels += entry['fine_labels']
#                 fo.close()

#             self.train_data = np.concatenate(self.train_data)
#             self.train_data = self.train_data.reshape((50000, 3, 32, 32))
#             self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
#             if noise_type is not None:
#                 # noisify train data
#                 self.train_labels=np.asarray([[self.train_labels[i]] for i in range(len(self.train_labels))])
#                 self.train_noisy_labels, self.actual_noise_rate = noisify(dataset=self.dataset, train_labels=self.train_labels, noise_type=noise_type, noise_rate=noise_rate, random_state=random_state, nb_classes=self.nb_classes)
#                 self.train_noisy_labels=[i[0] for i in self.train_noisy_labels]
#                 _train_labels=[i[0] for i in self.train_labels]
#                 self.noise_or_not = np.transpose(self.train_noisy_labels)==np.transpose(_train_labels)
#         else:
#             f = self.test_list[0][0]
#             file = os.path.join(self.root, self.base_folder, f)
#             fo = open(file, 'rb')
#             if sys.version_info[0] == 2:
#                 entry = pickle.load(fo)
#             else:
#                 entry = pickle.load(fo, encoding='latin1')
#             self.test_data = entry['data']
#             if 'labels' in entry:
#                 self.test_labels = entry['labels']
#             else:
#                 self.test_labels = entry['fine_labels']
#             fo.close()
#             self.test_data = self.test_data.reshape((10000, 3, 32, 32))
#             self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC

#     def __getitem__(self, index):
#         """
#         Args:
#             index (int): Index

#         Returns:
#             tuple: (image, target) where target is index of the target class.
#         """
#         if self.train:
#             if self.noise_type is not None:
#                 img, target = self.train_data[index], self.train_noisy_labels[index]
#             else:
#                 img, target = self.train_data[index], self.train_labels[index]
#         else:
#             img, target = self.test_data[index], self.test_labels[index]

#         # doing this so that it is consistent with all other datasets
#         # to return a PIL Image
#         img = Image.fromarray(img)

#         if self.transform is not None:
#             img = self.transform(img)

#         if self.target_transform is not None:
#             target = self.target_transform(target)

#         return img, target, index

#     def __len__(self):
#         if self.train:
#             return len(self.train_data)
#         else:
#             return len(self.test_data)

#     def _check_integrity(self):
#         root = self.root
#         for fentry in (self.train_list + self.test_list):
#             filename, md5 = fentry[0], fentry[1]
#             fpath = os.path.join(root, self.base_folder, filename)
#             if not check_integrity(fpath, md5):
#                 return False
#         return True

#     def download(self):
#         import tarfile

#         if self._check_integrity():
#             print('Files already downloaded and verified')
#             return

#         root = self.root
#         download_url(self.url, root, self.filename, self.tgz_md5)

#         # extract file
#         cwd = os.getcwd()
#         tar = tarfile.open(os.path.join(root, self.filename), "r:gz")
#         os.chdir(root)
#         tar.extractall()
#         tar.close()
#         os.chdir(cwd)

#     def __repr__(self):
#         fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
#         fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
#         tmp = 'train' if self.train is True else 'test'
#         fmt_str += '    Split: {}\n'.format(tmp)
#         fmt_str += '    Root Location: {}\n'.format(self.root)
#         tmp = '    Transforms (if any): '
#         fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
#         tmp = '    Target Transforms (if any): '
#         fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
#         return fmt_str





In [94]:
# cd ..

In [95]:
# import torch
# import torchvision
# from torch.utils.data import Dataset, DataLoader
# import numpy as np
# import math

root_dir = 'JoCoR_bach/Data/train_image/'
csv_file = 'Clean_train_data_encd.csv'



import torch
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
import torchvision.transforms as transforms  # Transformations we can perform on our dataset
import torchvision
import os
import pandas as pd
from skimage import io
from torch.utils.data import Dataset,DataLoader  # Gives easier dataset managment and creates mini batches
from torchvision.transforms.functional import InterpolationMode
from skimage import exposure, img_as_ubyte

# class HAM10000(Dataset):
#     def __init__(self, csv_file="/home/subrat/JoCoR-env/Noisy_final_encoded", transform=None):
#         self.csv_data = pd.read_csv(csv_file)
#         self.root_dir = root_dir
#         self.transform = transform

#     def __len__(self):
#         return len(self.csv_data)

#     def __getitem__(self, index):
#         label=self.csv_data.loc[index, 'dx']
#         img_path = self.csv_data.loc[index, 'image_id']
#         #img_path = os.path.join(root_dir, (annotations.iloc[index, 1] + '.jpg'))
#         image = io.imread(img_path)
# #         y_label = torch.tensor(int(self.annotations.iloc[index, 2]))

#         if self.transform:
#             image = self.transform(image)

#         return (image, label)

In [96]:
class BACH(Dataset):
    def __init__(self, csv_file="/home/subrat/JoCoR_bach/Data/Clean_train_data_encd.csv", transform=None):
        self.csv_data = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.csv_data)

    def __getitem__(self, index):
        label=self.csv_data.loc[index, 'label']
        img_path = self.csv_data.loc[index, 'Name']
#         print(img_path)
        #img_path = os.path.join(root_dir, (annotations.iloc[index, 1] + '.jpg'))
        img_path = os.path.join(root_dir,img_path)
        #print(img_path)
        image = io.imread(img_path)
        #print(image.shape)

        image = img_as_ubyte(exposure.rescale_intensity(image))
        plt.imshow(image, cmap=None)
#         transform = transforms.Resize((300,350))
#         resized_img = transform(image)
#         print(resized_img.shape)
#         y_label = torch.tensor(int(self.annotations.iloc[index, 2]))

        if self.transform:
            image = self.transform(image)

        return (image, label)


In [97]:
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import torch
import codecs
from JoCoR.data.utils import noisify


class HAM10000(data.Dataset):
#     """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

#     Args:
#         root (string): Root directory of dataset where ``processed/training.pt``
#             and  ``processed/test.pt`` exist.
#         train (bool, optional): If True, creates dataset from ``training.pt``,
#             otherwise from ``test.pt``.
#         download (bool, optional): If true, downloads the dataset from the internet and
#             puts it in root directory. If dataset is already downloaded, it is not
#             downloaded again.
#         transform (callable, optional): A function/transform that  takes in an PIL image
#             and returns a transformed version. E.g, ``transforms.RandomCrop``
#         target_transform (callable, optional): A function/transform that takes in the
#             target and transforms it.
#     """
#     urls = [
#         'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
#         'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
#         'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
#         'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
#     ]
#     raw_folder = 'raw'
#     processed_folder = 'processed'
#     training_file = 'training.pt'
#     test_file = 'test.pt'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False,
                 noise_type=None, noise_rate=0.2, random_state=0):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        self.dataset='/home/subrat/JoCoR-env/archive/HAM10000_images_part_1/'
        self.noise_type=noise_type

#         if download:
#             self.download()

#         if not self._check_exists():
#             raise RuntimeError('Dataset not found.' +
#                                ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))

            if noise_type != 'clean':
                self.train_labels=np.asarray([[self.train_labels[i]] for i in range(len(self.train_labels))])
                self.train_noisy_labels, self.actual_noise_rate = noisify(dataset=self.dataset, train_labels=self.train_labels, noise_type=noise_type, noise_rate=noise_rate, random_state=random_state)
                self.train_noisy_labels=[i[0] for i in self.train_noisy_labels]
                _train_labels=[i[0] for i in self.train_labels]
                self.noise_or_not = np.transpose(self.train_noisy_labels)==np.transpose(_train_labels)
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            #if self.noise_type is not None:
            if self.noise_type != 'clean':
                img, target = self.train_data[index], self.train_noisy_labels[index]
            else:
                img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
            os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))

#     def download(self):
#         """Download the MNIST data if it doesn't exist in processed_folder already."""
#         from six.moves import urllib
#         import gzip

#         if self._check_exists():
#             return

        # download files
#         try:
#             os.makedirs(os.path.join(self.root, self.raw_folder))
#             os.makedirs(os.path.join(self.root, self.processed_folder))
#         except OSError as e:
#             if e.errno == errno.EEXIST:
#                 pass
#             else:
#                 raise

#         for url in self.urls:
#             print('Downloading ' + url)
#             data = urllib.request.urlopen(url)
#             filename = url.rpartition('/')[2]
#             file_path = os.path.join(self.root, self.raw_folder, filename)
#             with open(file_path, 'wb') as f:
#                 f.write(data.read())
#             with open(file_path.replace('.gz', ''), 'wb') as out_f, \
#                     gzip.GzipFile(file_path) as zip_f:
#                 out_f.write(zip_f.read())
#             os.unlink(file_path)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, '/home/subrat/JoCoR-env/archive/HAM10000_images_part_1/')),
            read_label_file(os.path.join(self.root, self.raw_folder, '/home/subrat/JoCoR-env/archive/HAM10000_images_part_1/'))
        )
        test_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, '/home/subrat/JoCoR-env/archive/HAM10000_images_part_2/')),
            read_label_file(os.path.join(self.root, self.raw_folder, '/home/subrat/JoCoR-env/archive/HAM10000_images_part_2/'))
        )
        with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = 'train' if self.train is True else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


def get_int(b):
    return int(codecs.encode(b, 'hex'), 16)


def read_label_file(path):
    with open(path, 'rb') as f:
        data = f.read()
        assert get_int(data[:4]) == 2049
        length = get_int(data[4:8])
        parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
        return torch.from_numpy(parsed).view(length).long()


def read_image_file(path):
    with open(path, 'rb') as f:
        data = f.read()
        assert get_int(data[:4]) == 2051
        length = get_int(data[4:8])
        num_rows = get_int(data[8:12])
        num_cols = get_int(data[12:16])
        images = []
        parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
        return torch.from_numpy(parsed).view(length, num_rows, num_cols)

model

cnn.py

In [98]:
# import math
# import torch
# import torch.nn as nn
# import torch.nn.init as init 
# import torch.nn.functional as F
# import torch.optim as optim

# def call_bn(bn, x):
#     return bn(x)

# class CNN(nn.Module):
#     def __init__(self, input_channel=3, n_outputs=10, dropout_rate=0.25, momentum=0.1):
#         self.dropout_rate = dropout_rate
#         self.momentum = momentum
#         super(CNN, self).__init__()
#         self.c1=nn.Conv2d(input_channel, 64,kernel_size=3,stride=1, padding=1)
#         self.c2=nn.Conv2d(64,64,kernel_size=3,stride=1, padding=1)
#         self.c3=nn.Conv2d(64,128,kernel_size=3,stride=1, padding=1)
#         self.c4=nn.Conv2d(128,128,kernel_size=3,stride=1, padding=1)
#         self.c5=nn.Conv2d(128,196,kernel_size=3,stride=1, padding=1)
#         self.c6=nn.Conv2d(196,16,kernel_size=3,stride=1, padding=1)
#         self.linear1=nn.Linear(256, n_outputs)
#         self.bn1=nn.BatchNorm2d(64, momentum=self.momentum)
#         self.bn2=nn.BatchNorm2d(64, momentum=self.momentum)
#         self.bn3=nn.BatchNorm2d(128, momentum=self.momentum)
#         self.bn4=nn.BatchNorm2d(128, momentum=self.momentum)
#         self.bn5=nn.BatchNorm2d(196, momentum=self.momentum)
#         self.bn6=nn.BatchNorm2d(16, momentum=self.momentum)

#     def forward(self, x,):
#         h=x
#         h=self.c1(h)
#         h=F.relu(call_bn(self.bn1, h))
#         h=self.c2(h)
#         h=F.relu(call_bn(self.bn2, h))
#         h=F.max_pool2d(h, kernel_size=2, stride=2)

#         h=self.c3(h)
#         h=F.relu(call_bn(self.bn3, h))
#         h=self.c4(h)
#         h=F.relu(call_bn(self.bn4, h))
#         h=F.max_pool2d(h, kernel_size=2, stride=2)

#         h=self.c5(h)
#         h=F.relu(call_bn(self.bn5, h))
#         h=self.c6(h)
#         h=F.relu(call_bn(self.bn6, h))
#         h=F.max_pool2d(h, kernel_size=2, stride=2)

#         h = h.view(h.size(0), -1)
#         logit=self.linear1(h)
#         return logit

# class MLPNet(nn.Module):
#     def __init__(self):
#         super(MLPNet, self).__init__()
#         self.fc1 = nn.Linear(28 * 28, 256)
#         self.fc2 = nn.Linear(256, 10)

#     def forward(self, x):
#         x = x.view(-1, 28 * 28)
#         x = F.relu(self.fc1(x))
#         x = self.fc2(x)
#         return x



In [99]:
# import math
# import torch
# import torch.nn as nn
# import torch.nn.init as init 
# import torch.nn.functional as F
# import torch.optim as optim

# class MLPnet(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.layers = nn.Sequential(
#         nn.Flatten(),
#         nn.Linear(450 * 600 * 3, 64),
#         nn.ReLU(),
#         nn.Linear(64, 5)
#         )

#     def forward(self, x):
#         #x = x.view(16, 3 * 450 * 600)
#         #x = F.relu(self.fc1(x))
#         #x = self.fc2(x)
#         return self.layers(x)





# # Set device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Hyperparameters
# in_channel = 3
# num_classes = 5
# learning_rate = 1e-3
# batch_size = 16
# num_epochs = 1

# # Load Data
# transform=transforms.ToTensor()
# dataset = HAM10000(transform=transform)

# # Dataset is actually a lot larger ~25k images, just took out 10 pictures
# # to upload to Github. It's enough to understand the structure and scale
# # if you got more images.

# train_size = int(0.8 * len(dataset))
# test_size = len(dataset) - train_size
# train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
# # train_set, test_set = torch.utils.data.random_split(dataset, train_set[0],test_set[0])
# train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)
# # print(train_loader.label)
# print(f'No of batch loaded for training: {len(train_loader)}')
# # Model
# model = torchvision.models.googlenet(pretrained=True)
# model.to(device)

# # Loss and optimizer
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# # Train Network
# for epoch in range(num_epochs):
#     losses = []

#     for batch_idx, batch in enumerate(train_loader):
#         # Get data to cuda if possible
# #         print(batch_idx, batch[0].size(),
# #           batch[1].size())
# #         print(f'data is :', data)

#         data = batch[0]
#         targets = batch[1]
#         data = data.to(device=device)
#         targets = targets.to(device=device)
# #         
#         # forward
#         output = model(data)
#         loss = criterion(output, targets)

#         losses.append(loss.item())

#         # backward
#         optimizer.zero_grad()
#         loss.backward()

#         # gradient descent or adam step
#         optimizer.step()

#     print(f"Cost at epoch {epoch} is {sum(losses)/len(losses)}")

# # Check accuracy on training to see how good our model is
# def check_accuracy(loader, model):
#     num_correct = 0
#     num_incorrect = 0
#     num_samples = 0
#     model.eval()
    

#     with torch.no_grad():
#         for x, yT, yF in loader:
#             x = x.to(device=device)
#             yT = yT.to(device=device)
#             yF = yF.to(device=device)
#             #print(yT.numpy()) 
           

#             scores = model(x)
#             _, predictions = scores.max(1)
#             num_correct += (predictions == yT).sum()
#             num_incorrect += (predictions == yF).sum()
#             num_samples += predictions.size(0)

#         print(
#              f"Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}"
#         )

#         model.train()


# print("Checking accuracy on Training Set")
# check_accuracy(train_loader, model)

# print("Checking accuracy on Test Set")
# check_accuracy(test_loader, model)

main.**py**

In [102]:
# -*- coding:utf-8 -*-
import os
import torch
import torchvision.transforms as transforms
# from JoCoR.data.cifar import CIFAR10, CIFAR100
# from JoCoR.mnist_HAM import HAM10000
import argparse, sys
import datetime





parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--result_dir', type=str, help='dir to save result txt files', default='results')
parser.add_argument('--noise_rate', type=float, help='corruption rate, should be less than 1', default=0.2)
parser.add_argument('--forget_rate', type=float, help='forget rate', default=None)
parser.add_argument('--noise_type', type=str, help='[pairflip, symmetric]', default='pairflip')
parser.add_argument('--num_gradual', type=int, default=2,
                    help='how many epochs for linear drop rate, can be 5, 10, 15. This parameter is equal to Tk for R(T) in Co-teaching paper.')
parser.add_argument('--exponent', type=float, default=1,
                    help='exponent of the forget rate, can be 0.5, 1, 2. This parameter is equal to c in Tc for R(T) in Co-teaching paper.')
parser.add_argument('--dataset', type=str, help='mnist, cifar10, or cifar100', default='bach')
parser.add_argument('--n_epoch', type=int, default=200)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--print_freq', type=int, default=50)
parser.add_argument('--num_workers', type=int, default=4, help='how many subprocesses to use for data loading')
parser.add_argument('--num_iter_per_epoch', type=int, default=400)
parser.add_argument('--epoch_decay_start', type=int, default=80)
parser.add_argument('--gpu', type=int, default=None)
parser.add_argument('--co_lambda', type=float, default=0.1)
parser.add_argument('--adjust_lr', type=int, default=1)
parser.add_argument('--model_type', type=str, help='[mlp,cnn]', default='cnn')
parser.add_argument('--save_model', type=str, help='save model?', default="False")
parser.add_argument('--save_result', type=str, help='save result?', default="True")



args = parser.parse_args(args=[])

# Seed
torch.manual_seed(args.seed)
if args.gpu is not None:
    device = torch.device('cuda:{}'.format(args.gpu))
    torch.cuda.manual_seed(args.seed)

else:
    device = torch.device('cpu')
    torch.manual_seed(args.seed)

# Hyper Parameters
batch_size = 5
learning_rate = args.lr

# load dataset
# if args.dataset == 'HAM10000':
#     input_channel = 3
#     num_classes = 7
#     init_epoch = 10
#     filter_outlier = True
#     args.epoch_decay_start = 40
#     args.model_type = "mlp"
#     args.n_epoch = 200
#     dataset = HAM10000(root='./home/subrat/JoCoR-env/archive/HAM10000_images_part_1/',
#                           train=True,
#                           transform=transforms.ToTensor(),
#                           noise_type=args.noise_type,
#                           noise_rate=args.noise_rate
#                           )
#     print(dataset)
#     test_dataset = HAM10000(root='/home/subrat/JoCoR-env/archive/HAM10000_images_part_1/',
#                          train=False,
#                          transform=transforms.ToTensor(),
#                          noise_type=args.noise_type,
#                          noise_rate=args.noise_rate
#                          )

# if args.dataset == 'cifar10':
#     input_channel = 3
#     num_classes = 10
#     init_epoch = 20
#     args.epoch_decay_start = 80
#     filter_outlier = True
#     args.model_type = "cnn"
#     # args.n_epoch = 200
#     train_dataset = CIFAR10(root='./data/',
#                             download=True,
#                             train=True,
#                             transform=transforms.ToTensor(),
#                             noise_type=args.noise_type,
#                             noise_rate=args.noise_rate
#                             )

#     test_dataset = CIFAR10(root='./data/',
#                            download=True,
#                            train=False,
#                            transform=transforms.ToTensor(),
#                            noise_type=args.noise_type,
#                            noise_rate=args.noise_rate
#                            )

# if args.dataset == 'cifar100':
#     input_channel = 3
#     num_classes = 100
#     init_epoch = 5
#     args.epoch_decay_start = 100
#     # args.n_epoch = 200
#     filter_outlier = False
#     args.model_type = "cnn"


#     train_dataset = CIFAR100(root='./data/',
#                              download=True,
#                              train=True,
#                              transform=transforms.ToTensor(),
#                              noise_type=args.noise_type,
#                              noise_rate=args.noise_rate
#                              )

#     test_dataset = CIFAR100(root='./data/',
#                             download=True,
#                             train=False,
#                             transform=transforms.ToTensor(),
#                             noise_type=args.noise_type,
#                             noise_rate=args.noise_rate
#                             )

if args.forget_rate is None:
    forget_rate = args.noise_rate
else:
    forget_rate = args.forget_rate

    
# input_channel = 3
# num_classes = 7
# init_epoch = 10
# filter_outlier = True
# args.epoch_decay_start = 40
# args.model_type = "mlp"
# args.n_epoch = 200
# dataset = HAM10000(root='./home/subrat/JoCoR-env/archive/HAM10000_images_part_1/',
#                       train=True,
#                       transform=transforms.ToTensor(),
#                       noise_type=args.noise_type,
#                       noise_rate=args.noise_rate
#                       )
# Load Data
# transform for rectangular resize

transform=transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(512, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
# img = Image.fromarray(np.astype(np.uint8))


# load dataset
if args.dataset == 'bach':
    input_channel = 3
    num_classes = 4
    init_epoch = 10
    filter_outlier = True
    args.epoch_decay_start = 40
    args.model_type = "mlp"
    args.n_epoch = 2



dataset = BACH(transform=transform)


print(dataset)

batch_size = 10   
train_percentage = 0.4
val_percentage = 0.3
train_size = int(train_percentage * len(dataset))
val_size = int(val_percentage * len(dataset))
test_size = len(dataset) - train_size - val_size


def main():
    # Data Loader (Input Pipeline)
#     print('loading dataset...')
    

    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
    # train_set, test_set = torch.utils.data.random_split(dataset, train_set[0],test_set[0])
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=2, shuffle=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, num_workers=2, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=2)
#     train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
#                                                batch_size=batch_size,
#                                                num_workers=args.num_workers,
#                                                drop_last=True,
#                                                shuffle=True)

#     test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
#                                               batch_size=batch_size,
#                                               num_workers=args.num_workers,
#                                               drop_last=True,
#                                               shuffle=False)
    # Define models
    print('building model...')

    model = JoCoR(args, train_dataset, device, input_channel, num_classes)
    print(model)

    epoch = 0
    train_acc1 = 0
    train_acc2 = 0

    # evaluate models with random weights
    test_acc1, test_acc2 = model.evaluate(test_loader)

    print(
        'Epoch [%d/%d] Test Accuracy on the %s test images: Model1 %.4f %% Model2 %.4f ' % (
            epoch + 1, args.n_epoch, len(test_dataset), test_acc1, test_acc2))


    acc_list = []
    # training
    for epoch in range(1, args.n_epoch):
        # train models
        train_acc1, train_acc2, pure_ratio_1_list, pure_ratio_2_list = model.train(train_loader, epoch)

        # evaluate models
        test_acc1, test_acc2, = model.evaluate(test_loader)

        # save results
        if pure_ratio_1_list is None or len(pure_ratio_1_list) == 0:
            print(
                'Epoch [%d/%d] Test Accuracy on the %s test images: Model1 %.4f %% Model2 %.4f' % (
                    epoch + 1, args.n_epoch, len(test_dataset), test_acc1, test_acc2))
        else:
            # save results
            mean_pure_ratio1 = sum(pure_ratio_1_list) / len(pure_ratio_1_list)
            mean_pure_ratio2 = sum(pure_ratio_2_list) / len(pure_ratio_2_list)
            print(
                'Epoch [%d/%d] Test Accuracy on the %s test images: Model1 %.4f %% Model2 %.4f %%, Pure Ratio 1 %.4f %%, Pure Ratio 2 %.4f %%' % (
                    epoch + 1, args.n_epoch, len(test_dataset), test_acc1, test_acc2, mean_pure_ratio1,
                    mean_pure_ratio2))


        if epoch >= 190:
            acc_list.extend([test_acc1,test_acc2])

    avg_acc = sum(acc_list)/len(acc_list)
    print(len(acc_list))
    print("the average acc in last 10 epochs: {}".format(str(avg_acc)))




<__main__.BACH object at 0x7f44d10a5b80>


In [103]:
if __name__ == '__main__':
    main()

building model...
<bound method Module.parameters of MLPnet(
  (layers): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=810000, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=5, bias=True)
  )
)>
<bound method Module.parameters of MLPnet(
  (layers): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=810000, out_features=64, bias=True)
    (2): ReLU()
    (3): Linear(in_features=64, out_features=5, bias=True)
  )
)>
<__main__.JoCoR object at 0x7f44d10a2160>
Evaluating ...


RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x1047552 and 810000x64)

In [None]:
cd /content/drive/MyDrive/JoCoR/data

In [None]:
!python main.py --dataset cifar10 --noise_type symmetric --noise_rate 0.5 

In [None]:
!python main.py --dataset cifar10 --noise_type symmetric --noise_rate 0.5 --co_lambda 0.9