Ufuk Altun

using Pkg; pkg"add CUDA Knet IterTools FileIO JLD2 Images Plots PyCall"

In [1]:
using PyCall
@pyimport torch

## Use Pycall to imitate dataloading of the original code

['A', 'C', 'P', 'S']
['C', 'L', 'S', 'V']

In [17]:
using PyCall

py"""
import argparse
import os
import random
import torch
import numpy as np
from torchvision import transforms
from PIL import Image, ImageFile
#import datasets

parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default="/Users/ufukaltun/PycharmProjects/DGVGS/DGvGS-main/data/pre")
parser.add_argument('--dataset', type=str, default='VLCS')
parser.add_argument('--method', type=str, default='deep-all')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--iterations', type=int, default=600)
parser.add_argument('--val_every', type=int, default=20)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--weight_decay', type=float, default=5e-5)
parser.add_argument('--test_dom_idx', type=int, default=0)
parser.add_argument('--output_dir', type=str, default="result/train")
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
device = 'cuda' if torch.cuda.is_available() else 'cpu'


ImageFile.LOAD_TRUNCATED_IMAGES = True


def get_domains(dataset_name):
    return globals()[dataset_name].DOMAINS


class ImageDataset(torch.utils.data.Dataset):

    def __init__(self, file_path, image_dir, transform=None):
        self.file_path = file_path
        self.image_dir = image_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self._read_file()

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = os.path.join(self.image_dir, self.image_paths[idx])

        with open(path, 'rb') as f:
            image = Image.open(f).convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def _read_file(self):
        with open(self.file_path) as f:
            for line in f:
                path, label = line.strip().split(',')
                self.image_paths.append(path)
                self.labels.append(int(label) - 1)

class MultiDomainDataset:

    def __init__(self, root_dir, test_dom_idx):
        images_dir = os.path.join(root_dir, 'images')
        split_dir = os.path.join(root_dir, 'split')

        domains = [f.name for f in os.scandir(images_dir) if f.is_dir()]
        domains.sort()

        test_dom = domains[test_dom_idx]
        train_doms = [d for d in domains if d != test_dom]

        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])])

        train_datasets, val_datasets = [], []
        for dom_name in train_doms:
            train_datasets.append(ImageDataset(
                os.path.join(split_dir, dom_name + '_train.txt'),
                images_dir,
                transform))
            val_datasets.append(ImageDataset(
                os.path.join(split_dir, dom_name + '_val.txt'),
                images_dir,
                transform))

        self.datasets = {}
        self.datasets['train'] = train_datasets
        self.datasets['val'] = torch.utils.data.ConcatDataset(val_datasets)
        self.datasets['test'] = ImageDataset(
            os.path.join(split_dir, test_dom + '_test.txt'),
            images_dir,
            transform)

    def __getitem__(self, phase):
        if phase in ['train', 'val', 'test']:
            return self.datasets[phase]
        else:
            raise ValueError


class VLCS(MultiDomainDataset):

    N_CLASSES = 5
    DOMAINS = ['C', 'L', 'S', 'V']

    def __init__(self, root_dir, test_dom_idx):
        self.root_dir = os.path.join(root_dir, 'VLCS/')
        super().__init__(self.root_dir, test_dom_idx)


class PACS(MultiDomainDataset):

    N_CLASSES = 7
    DOMAINS = ['A', 'C', 'P', 'S']

    def __init__(self, root_dir, test_dom_idx):
        self.root_dir = os.path.join(root_dir, 'PACS/')
        super().__init__(self.root_dir, test_dom_idx)

############# data loader ###########

class InfiniteDataLoader(torch.utils.data.DataLoader):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_iterator = super().__iter__()

    def __iter__(self):
        return self

    def __next__(self):
        try:
            batch = next(self.dataset_iterator)
        except StopIteration:
            self.dataset_iterator = super().__iter__()
            batch = next(self.dataset_iterator)
        return batch





############# Model ##########

def get_model(device, dataset, args):
    if args.method == 'deep-all':
        return ModelDA(device, dataset, args)
    elif args.method in ['agr-sum', 'agr-rand', 'pcgrad']:
        return ModelGS(device, dataset, args)
    else:
        raise ValueError


class ModelDA:

    def __init__(self, device, dataset, args):
        self.device = device
        self.args = args
        self._create_dataloaders(dataset, args)

    def _create_dataloaders(self, dataset, args):

        def get_dataloader(dataset, batch_size, is_train=False):
            if is_train:
                return InfiniteDataLoader(
                    dataset=dataset,
                    batch_size=batch_size,
                    shuffle=True,
                    drop_last=True)
            else:
                return torch.utils.data.DataLoader(
                    dataset=dataset,
                    batch_size=batch_size,
                    shuffle=False,
                    drop_last=False)

        self.train_loaders = []
        for dom_dataset in dataset['train']:
            self.train_loaders.append(get_dataloader(dom_dataset, args.batch_size, True))
        self.val_loader = get_dataloader(dataset['val'], args.batch_size)
        self.test_loader = get_dataloader(dataset['test'], args.batch_size)

    def _prepare_batch(self, batch):
        inputs, targets = batch
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        return inputs, targets

    def train(self):
        train_iterator = zip(*self.train_loaders)
        val_iterator = self.val_loader
        test_iterator = self.test_loader
        
        iterations = self.args.iterations
        val_every = self.args.val_every

        return train_iterator, val_iterator, test_iterator
        #for it in range(iterations):
        # Training
        #    train_batches = [self._prepare_batch(batch) for batch in next(train_iterator)]
        #    train_loss, train_acc = self._train_step(train_batches)
        #    is_train = True
        #    inputs = torch.cat([x for x, _ in train_batches])
        #    targets = torch.cat([y for _, y in train_batches])



###################

#dataset = PACS(args.data_dir, args.test_dom_idx)
dataset = VLCS(args.data_dir, args.test_dom_idx)
#dataset = vars(datasets)[args.dataset](args.data_dir, args.test_dom_idx)
#dataset = PACS(args.data_dir, args.test_dom_idx)
model = ModelDA(device, dataset, args)
train_iterator, val_iterator, test_iterator = model.train()

def prep_batch(batch):
    inputs, targets = batch
    inputs = inputs.to('cpu')
    targets = targets.to('cpu')
    return inputs, targets
    

############## def return data #######
def return_data(train_iterator):
    train_batches = [prep_batch(batch) for batch in next(train_iterator)]
    inputs = torch.cat([x for x, _ in train_batches])
    targets = torch.cat([y for _, y in train_batches])
    return inputs, targets

def return_val_data(val_iterator):
    train_batches = [prep_batch(batch) for batch in val_iterator]
    inputs = torch.cat([x for x, _ in train_batches])
    targets = torch.cat([y for _, y in train_batches])
    return inputs, targets
    
def return_data_gs(train_iterator):
    train_batches = [prep_batch(batch) for batch in next(train_iterator)]
    return train_batches
    
"""   

In [3]:
include("models.jl");
include("training.jl");

#### Generate an alexnet model

In [18]:
model = generate_alexnet_model( 5 ; pretrained = true);

#### Set optimizers of each parameter to Adam. Parameter sizes should be:
##### Original Alexnet
- (11, 11, 3, 64)
- (1, 1, 64, 1)
- (5, 5, 64, 192)
- (1, 1, 192, 1)
- (3, 3, 192, 384)
- (1, 1, 384, 1)
- (3, 3, 384, 256)
- (1, 1, 256, 1)
- (3, 3, 256, 256)
- (1, 1, 256, 1)
- (4096, 9216)
- (4096,)
- (4096, 4096)
- (4096,)
- (7, 4096)
- (7,)

In [19]:
for param in params(model)
    param.opt = Adam(; lr=1e-5)
    println(size(param))
end

(11, 11, 3, 64)
(1, 1, 64, 1)
(5, 5, 64, 192)
(1, 1, 192, 1)
(3, 3, 192, 384)
(1, 1, 384, 1)
(3, 3, 384, 256)
(1, 1, 256, 1)
(3, 3, 256, 256)
(1, 1, 256, 1)
(4096, 9216)
(4096,)
(4096, 4096)
(4096,)
(5, 4096)
(5,)


#### Define batchsize, model path and target domain  

In [20]:
num_iter=600
check_freq=20
batchsize = 128
atype = Array{Float32}
data_path = "/Users/ufukaltun/Documents/koç/dersler/ku deep learning/project/data";
target = "Caltech101"

"Caltech101"

In [21]:
inputs_val, targets_val = py"return_val_data(val_iterator)"
inputs_val = permutedims(inputs_val.detach().numpy(), (4, 3, 2, 1));
targets_val = targets_val.detach().numpy();

In [22]:
inputs_test, targets_test = py"return_val_data(test_iterator)"
inputs_test = permutedims(inputs_test.detach().numpy(), (4, 3, 2, 1));
targets_test = targets_test.detach().numpy();

## Training of Baseline model with "Cartoon" target, PACS dataset

In [9]:
n=1
results=[]
best_acc=-1
for it in progress(1:num_iter)
    inputs_train, targets_train = py"return_data(train_iterator)"
    inputs_train = permutedims(inputs_train.detach().numpy(), (4, 3, 2, 1));
    targets_train = targets_train.detach().numpy();

    D = @diff model(inputs_train, targets_train);
    for w in Knet.params(model)
        g = grad(D, w)
        update!(w, g)
    end
    if n==1 
        println("First batch is trained successfully")
        best_acc = accuracy(model(inputs_val),targets_val)
    end
    if n%check_freq == 0
        acc = accuracy(model(inputs_val),targets_val)
        if acc > best_acc
            best_acc = acc
            save(string("best_model_b_debugging_target_", target,".jld2"),"weights",model)
        end
        result=(model(inputs_train, targets_train), model(inputs_val, targets_val), 
                    accuracy(model(inputs_train),targets_train), acc)
        println("Train loss: ",result[1],"  Val loss: ",result[2],
                "  Train acc: ",result[3],"  Val acc: ",result[4])
        push!(results,result)
    end

    (n += 1) > num_iter && break
end 
println("Training ended successfully, saving the results")
save(string("results_b_debugging_target_", target,".jld2"),"results",results)
best_model = load(string("best_model_b_debugging_target_", target,".jld2"))["weights"]
tst_loss = best_model(inputs_test, targets_test)
tst_acc = accuracy(model(inputs_test),targets_test)
println("Test loss: ", tst_loss, "  Test acc: ", tst_acc)

┣                    ┫ [0.17%, 1/600, 00:00/00:01, 995.64i/s] 

First batch is trained successfully


┣▋                   ┫ [3.33%, 20/600, 12:01/06:00:43, 34.64s/i] 

Train loss: 1.4085758

┣▋                   ┫ [3.50%, 21/600, 13:51/06:35:45, 109.64s/i] 

  Val loss: 1.5214493  Train acc: 0.5214723926380368  Val acc: 0.44849445324881143


┣█▎                  ┫ [6.83%, 41/600, 27:07/06:36:43, 100.03s/i] 

Train loss: 1.0361052  Val loss: 1.1410123  Train acc: 0.6276923076923077  Val acc: 0.5832012678288431


┣██                  ┫ [10.17%, 61/600, 40:22/06:37:01, 101.38s/i] 

Train loss: 0.6770923  Val loss: 0.90846246  Train acc: 0.7694805194805194  Val acc: 0.6576862123613312


┣██▋                 ┫ [13.50%, 81/600, 53:52/06:39:03, 108.77s/i] 

Train loss: 0.63129944  Val loss: 0.74905396  Train acc: 0.7641509433962265  Val acc: 0.7337559429477021


┣███▎                ┫ [16.83%, 101/600, 01:07:18/06:39:46, 102.69s/i] 

Train loss: 0.5136729  Val loss: 0.64311653  Train acc: 0.8170347003154574  Val acc: 0.7638668779714739


┣████                ┫ [20.17%, 121/600, 01:20:56/06:41:20, 101.50s/i] 

Train loss: 0.37836137  Val loss: 0.56522363  Train acc: 0.8475609756097561  Val acc: 0.7892234548335975


┣████▋               ┫ [23.50%, 141/600, 01:34:33/06:42:18, 100.09s/i] 

Train loss: 0.30146593  Val loss: 0.51837486  Train acc: 0.8987341772151899  Val acc: 0.8066561014263075


┣█████▎              ┫ [26.83%, 161/600, 01:48:25/06:44:02, 103.13s/i] 

Train loss: 0.31053635  Val loss: 0.48111856  Train acc: 0.8885448916408669  Val acc: 0.8256735340729001


┣██████              ┫ [30.17%, 181/600, 02:01:33/06:42:55, 103.84s/i] 

Train loss: 0.20385881  Val loss: 0.45375693  Train acc: 0.9341692789968652  Val acc: 0.8367670364500792


┣██████▋             ┫ [33.50%, 201/600, 02:14:38/06:41:54, 101.26s/i] 

Train loss: 0.2125772  Val loss: 0.42791003  Train acc: 0.9327217125382263  Val acc: 0.8431061806656102


┣███████▎            ┫ [36.83%, 221/600, 02:27:35/06:40:42, 99.63s/i] 

Train loss: 0.16916327  Val loss: 0.4108  Train acc: 0.9569230769230769  Val acc: 0.8573692551505546


┣████████            ┫ [40.17%, 241/600, 02:40:36/06:39:50, 100.73s/i] 

Train loss: 0.13714819  Val loss: 0.39227238  Train acc: 0.9501557632398754  Val acc: 0.8557844690966719


┣████████▋           ┫ [43.50%, 261/600, 02:53:41/06:39:16, 102.41s/i] 

Train loss: 0.11283547  Val loss: 0.381017  Train acc: 0.9622641509433962  Val acc: 0.8652931854199684


┣█████████▎          ┫ [46.83%, 281/600, 03:06:48/06:38:52, 100.00s/i] 

Train loss: 0.11470907  Val loss: 0.37611946  Train acc: 0.9551282051282052  Val acc: 0.8652931854199684


┣██████████          ┫ [50.17%, 301/600, 03:20:20/06:39:21, 102.44s/i] 

Train loss: 0.088965766  Val loss: 0.3620223  Train acc: 0.9747634069400631  Val acc: 0.8716323296354992


┣██████████▋         ┫ [53.50%, 321/600, 03:33:27/06:38:58, 103.21s/i] 

Train loss: 0.08686784  Val loss: 0.36093095  Train acc: 0.9693251533742331  Val acc: 0.8763866877971473


┣███████████▎        ┫ [56.83%, 341/600, 03:46:35/06:38:41, 102.60s/i] 

Train loss: 0.05186857  Val loss: 0.34806526  Train acc: 0.99079754601227  Val acc: 0.884310618066561


┣████████████        ┫ [60.17%, 361/600, 03:59:45/06:38:29, 101.11s/i] 

Train loss: 0.044794194  Val loss: 0.3440363  Train acc: 0.9876543209876543  Val acc: 0.884310618066561


┣████████████▋       ┫ [63.50%, 381/600, 04:13:04/06:38:31, 103.99s/i] 

Train loss: 0.055252444  Val loss: 0.34776396  Train acc: 0.9819277108433735  Val acc: 0.8874801901743264


┣█████████████▎      ┫ [66.83%, 401/600, 04:26:23/06:38:34, 100.44s/i] 

Train loss: 0.07424606  Val loss: 0.3427334  Train acc: 0.9782608695652174  Val acc: 0.8858954041204438


┣██████████████      ┫ [70.17%, 421/600, 04:39:23/06:38:11, 98.63s/i] 

Train loss: 0.036103826  Val loss: 0.3453525  Train acc: 0.9936507936507937  Val acc: 0.884310618066561


┣██████████████▋     ┫ [73.50%, 441/600, 04:52:23/06:37:48, 100.31s/i] 

Train loss: 0.028682392  Val loss: 0.3345934  Train acc: 0.9938461538461538  Val acc: 0.8827258320126783


┣███████████████▎    ┫ [76.83%, 461/600, 05:05:30/06:37:37, 100.20s/i] 

Train loss: 0.034097034  Val loss: 0.35122663  Train acc: 0.9935275080906149  Val acc: 0.8858954041204438


┣████████████████    ┫ [80.17%, 481/600, 05:18:38/06:37:27, 102.54s/i] 

Train loss: 0.033544566  Val loss: 0.33292723  Train acc: 0.9807692307692307  Val acc: 0.8922345483359746


┣████████████████▋   ┫ [83.50%, 501/600, 05:31:46/06:37:19, 100.75s/i] 

Train loss: 0.024951257  Val loss: 0.33240846  Train acc: 1.0  Val acc: 0.8922345483359746


┣█████████████████▎  ┫ [86.83%, 521/600, 05:44:53/06:37:10, 97.50s/i] 

Train loss: 0.016679421  Val loss: 0.33732545  Train acc: 1.0  Val acc: 0.8922345483359746


┣██████████████████  ┫ [90.17%, 541/600, 05:57:50/06:36:52, 98.13s/i] 

Train loss: 0.010127131  Val loss: 0.33860353  Train acc: 1.0  Val acc: 0.8890649762282092


┣██████████████████▋ ┫ [93.50%, 561/600, 06:10:54/06:36:41, 99.37s/i] 

Train loss: 0.007621357  Val loss: 0.34247118  Train acc: 1.0  Val acc: 0.8922345483359746


┣███████████████████▎┫ [96.83%, 581/600, 06:23:57/06:36:31, 101.88s/i] 

Train loss: 0.012343519  Val loss: 0.33897206  Train acc: 1.0  Val acc: 0.8969889064976229


┣████████████████████┫ [100.00%, 600/600, 06:35:24/06:35:24, 35.26s/i] 

Train loss: 0.013096458  Val loss: 0.34188887  Train acc: 1.0  Val acc: 0.8858954041204438
Training ended successfully, saving the results
Test loss: 1.3054203  Test acc: 0.6675191815856778


## Training of Gradient Surgery model with "Cartoon" target, PACS dataset

In [16]:
n=1
results=[]
best_acc=0
for it in progress(1:num_iter)
    train_batches = py"return_data_gs(train_iterator)"
    inputs1, targets1 = train_batches[1];
    inputs1 = permutedims(inputs1.detach().numpy(), (4, 3, 2, 1));
    targets1 = targets1.detach().numpy();
    
    inputs2, targets2 = train_batches[2];
    inputs2 = permutedims(inputs2.detach().numpy(), (4, 3, 2, 1));
    targets2 = targets2.detach().numpy();
    
    inputs3, targets3 = train_batches[3];
    inputs3 = permutedims(inputs3.detach().numpy(), (4, 3, 2, 1));
    targets3 = targets3.detach().numpy();
    
    D1 = @diff model(inputs1, targets1)
    D2 = @diff model(inputs2, targets2)
    D3 = @diff model(inputs3, targets3)

    for w in Knet.params(model)
        g1 = grad(D1, w)
        g2 = grad(D2, w)
        g3 = grad(D3, w)
        MaskGS = abs.(sign.(g1).+sign.(g2).+sign.(g3)).==3
        g = (g1.+g2.+g3).*MaskGS
        update!(w, g)
    end

    if n==1 
        println("First batch is trained successfully")
        best_acc = accuracy(model(inputs_val),targets_val)
    end

    if n%check_freq == 0
        acc = accuracy(model(inputs_val),targets_val)
        if acc > best_acc
            best_acc = acc
            save(string("best_model_gs_debugging_target_", target,".jld2"),"weights",model)
        end
        temp_loss = (model(inputs1, targets1) + model(inputs2, targets2) + model(inputs3, targets3))/3
        temp_acc = (accuracy(model(inputs1),targets1) + accuracy(model(inputs2),targets2) + accuracy(model(inputs3),targets3))/3
        result=(temp_loss, model(inputs_val, targets_val), temp_acc, acc)
        println("Train loss: ",result[1],"  Val loss: ",result[2],
                "  Train acc: ",result[3],"  Val acc: ",result[4])
        push!(results,result)
    end

    (n += 1) > num_iter && break
end 
println("Training ended successfully, saving the results")
save(string("results_gs_debugging_target_", target,".jld2"),"results",results)
best_model = load(string("best_model_gs_debugging_target_", target,".jld2"))["weights"]
tst_loss = best_model(inputs_test, targets_test)
tst_acc = accuracy(model(inputs_test),targets_test)
println("Test loss: ", tst_loss, "  Test acc: ", tst_acc)

┣                    ┫ [0.17%, 1/600, 00:00/00:00, 1839.36i/s] 

First batch is trained successfully


┣▋                   ┫ [3.33%, 20/600, 11:48/05:53:52, 34.37s/i] 

Train loss: 2.1176128

┣▋                   ┫ [3.50%, 21/600, 13:27/06:24:26, 99.56s/i] 

  Val loss: 2.133992  Train acc: 0.27100328645546606  Val acc: 0.2583201267828843


┣█▎                  ┫ [6.83%, 41/600, 26:51/06:32:49, 105.07s/i] 

Train loss: 1.6800386  Val loss: 1.6481905  Train acc: 0.38772794200928756  Val acc: 0.3692551505546751


┣██                  ┫ [10.17%, 61/600, 39:48/06:31:33, 100.15s/i] 

Train loss: 1.2505255  Val loss: 1.3738507  Train acc: 0.5027629233511587  Val acc: 0.48019017432646594


┣██▋                 ┫ [13.50%, 81/600, 53:06/06:33:19, 101.87s/i] 

Train loss: 1.1999702  Val loss: 1.1759523  Train acc: 0.5623292075907362  Val acc: 0.561014263074485


┣███▎                ┫ [16.83%, 101/600, 01:06:25/06:34:31, 103.42s/i] 

Train loss: 1.0379937  Val loss: 1.0238854  Train acc: 0.6066872238271833  Val acc: 0.6212361331220285


┣████                ┫ [20.17%, 121/600, 01:19:54/06:36:13, 101.19s/i] 

Train loss: 0.7890194  Val loss: 0.90811145  Train acc: 0.7098022228526012  Val acc: 0.6719492868462758


┣████▋               ┫ [23.50%, 141/600, 01:32:54/06:35:17, 99.73s/i] 

Train loss: 0.67838126  Val loss: 0.8182379  Train acc: 0.737755381873029  Val acc: 0.7068145800316957


┣█████▎              ┫ [26.83%, 161/600, 01:45:52/06:34:32, 100.04s/i] 

Train loss: 0.70036966  Val loss: 0.74389416  Train acc: 0.7660866910866911  Val acc: 0.7194928684627575


┣██████              ┫ [30.17%, 181/600, 01:59:10/06:35:03, 101.89s/i] 

Train loss: 0.5425684  Val loss: 0.6851195  Train acc: 0.8052278637838145  Val acc: 0.7527733755942948


┣██████▋             ┫ [33.50%, 201/600, 02:12:04/06:34:14, 101.11s/i] 

Train loss: 0.5990903  Val loss: 0.6419718  Train acc: 0.8150118299747083  Val acc: 0.7702060221870047


┣███████▎            ┫ [36.83%, 221/600, 02:25:04/06:33:50, 100.55s/i] 

Train loss: 0.48171905  Val loss: 0.6102722  Train acc: 0.8274124345194087  Val acc: 0.7765451664025357


┣████████            ┫ [40.17%, 241/600, 02:38:31/06:34:40, 101.14s/i] 

Train loss: 0.44776157  Val loss: 0.5823609  Train acc: 0.824278226141942  Val acc: 0.7812995245641838


┣████████▋           ┫ [43.50%, 261/600, 02:51:36/06:34:29, 101.69s/i] 

Train loss: 0.43215403  Val loss: 0.55676997  Train acc: 0.8546257364090057  Val acc: 0.7812995245641838


┣█████████▎          ┫ [46.83%, 281/600, 03:04:40/06:34:18, 102.96s/i] 

Train loss: 0.4051808  Val loss: 0.5389829  Train acc: 0.8652681992337165  Val acc: 0.7955625990491284


┣██████████          ┫ [50.17%, 301/600, 03:17:52/06:34:25, 103.77s/i] 

Train loss: 0.339008  Val loss: 0.51649094  Train acc: 0.8786737142000299  Val acc: 0.8050713153724247


┣██████████▋         ┫ [53.50%, 321/600, 03:31:00/06:34:23, 100.28s/i] 

Train loss: 0.34300157  Val loss: 0.49576774  Train acc: 0.8796899437622708  Val acc: 0.8193343898573693


┣███████████▎        ┫ [56.83%, 341/600, 03:43:55/06:34:00, 100.45s/i] 

Train loss: 0.25773257  Val loss: 0.4840932  Train acc: 0.9156308892367054  Val acc: 0.8225039619651348


┣████████████        ┫ [60.17%, 361/600, 03:56:59/06:33:53, 101.92s/i] 

Train loss: 0.24254006  Val loss: 0.47028902  Train acc: 0.9227599524374734  Val acc: 0.8256735340729001


┣████████████▋       ┫ [63.50%, 381/600, 04:10:10/06:33:58, 101.46s/i] 

Train loss: 0.28257778  Val loss: 0.45632926  Train acc: 0.8959134501258129  Val acc: 0.8367670364500792


┣█████████████▎      ┫ [66.83%, 401/600, 04:23:30/06:34:17, 99.54s/i] 

Train loss: 0.2923469  Val loss: 0.44868937  Train acc: 0.9116621682411156  Val acc: 0.838351822503962


┣██████████████      ┫ [70.17%, 421/600, 04:36:42/06:34:20, 103.94s/i] 

Train loss: 0.28189528  Val loss: 0.4471563  Train acc: 0.8982232716396107  Val acc: 0.8335974643423137


┣██████████████▋     ┫ [73.50%, 441/600, 04:49:51/06:34:21, 100.21s/i] 

Train loss: 0.21448715  Val loss: 0.4292802  Train acc: 0.9372210876635655  Val acc: 0.8462757527733756


┣███████████████▎    ┫ [76.83%, 461/600, 05:03:03/06:34:26, 102.85s/i] 

Train loss: 0.20756511  Val loss: 0.42194444  Train acc: 0.9267087349467212  Val acc: 0.8510301109350238


┣████████████████    ┫ [80.17%, 481/600, 05:16:10/06:34:23, 101.48s/i] 

Train loss: 0.2550551  Val loss: 0.41701376  Train acc: 0.9049863716530383  Val acc: 0.8462757527733756


┣████████████████▋   ┫ [83.50%, 501/600, 05:29:13/06:34:17, 102.45s/i] 

Train loss: 0.18725419  Val loss: 0.40954953  Train acc: 0.9320718901666325  Val acc: 0.8605388272583201


┣█████████████████▎  ┫ [86.83%, 521/600, 05:42:26/06:34:22, 98.70s/i] 

Train loss: 0.19949584  Val loss: 0.4026128  Train acc: 0.9410641449562281  Val acc: 0.8605388272583201


┣██████████████████  ┫ [90.17%, 541/600, 05:55:41/06:34:29, 101.55s/i] 

Train loss: 0.1554196  Val loss: 0.39998448  Train acc: 0.9493226898887276  Val acc: 0.8621236133122029


┣██████████████████▋ ┫ [93.50%, 561/600, 06:08:48/06:34:26, 102.57s/i] 

Train loss: 0.14438382  Val loss: 0.3924485  Train acc: 0.9553701015965167  Val acc: 0.8652931854199684


┣███████████████████▎┫ [96.83%, 581/600, 06:21:50/06:34:19, 97.60s/i] 

Train loss: 0.12688617  Val loss: 0.3879612  Train acc: 0.9469829598506069  Val acc: 0.8621236133122029


┣████████████████████┫ [100.00%, 600/600, 06:33:24/06:33:24, 37.68s/i] 

Train loss: 0.12101819  Val loss: 0.38308683  Train acc: 0.9436403289558629  Val acc: 0.8763866877971473
Training ended successfully, saving the results
Test loss: 1.0162643  Test acc: 0.7135549872122762


## Training of Baseline model with "Photo" target, PACS dataset

In [27]:
n=1
results=[]
best_acc=-1
for it in progress(1:num_iter)
    inputs_train, targets_train = py"return_data(train_iterator)"
    inputs_train = permutedims(inputs_train.detach().numpy(), (4, 3, 2, 1));
    targets_train = targets_train.detach().numpy();

    D = @diff model(inputs_train, targets_train);
    for w in Knet.params(model)
        g = grad(D, w)
        update!(w, g)
    end
    if n==1 
        println("First batch is trained successfully")
        best_acc = accuracy(model(inputs_val),targets_val)
    end
    if n%check_freq == 0
        acc = accuracy(model(inputs_val),targets_val)
        if acc > best_acc
            best_acc = acc
            save(string("best_model_b_debugging_target_", target,".jld2"),"weights",model)
        end
        result=(model(inputs_train, targets_train), model(inputs_val, targets_val), 
                    accuracy(model(inputs_train),targets_train), acc)
        println("Train loss: ",result[1],"  Val loss: ",result[2],
                "  Train acc: ",result[3],"  Val acc: ",result[4])
        push!(results,result)
    end

    (n += 1) > num_iter && break
end 
println("Training ended successfully, saving the results")
save(string("results_b_debugging_target_", target,".jld2"),"results",results)
best_model = load(string("best_model_b_debugging_target_", target,".jld2"))["weights"]
tst_loss = best_model(inputs_test, targets_test)
tst_acc = accuracy(model(inputs_test),targets_test)
println("Test loss: ", tst_loss, "  Test acc: ", tst_acc)

┣                    ┫ [0.17%, 1/600, 00:00/00:00, 2896.10i/s] 

First batch is trained successfully


┣▋                   ┫ [3.50%, 21/600, 13:49/06:34:54, 104.63s/i] 

Train loss: 1.527172  Val loss: 1.5901024  Train acc: 0.4634920634920635  Val acc: 0.4020618556701031


┣█▎                  ┫ [6.83%, 41/600, 27:01/06:35:15, 104.25s/i] 

Train loss: 1.2078202  Val loss: 1.2394893  Train acc: 0.5844155844155844  Val acc: 0.5493372606774669


┣██                  ┫ [10.17%, 61/600, 40:10/06:35:05, 104.29s/i] 

Train loss: 0.9662096  Val loss: 0.9912996  Train acc: 0.7243589743589743  Val acc: 0.6627393225331369


┣██▋                 ┫ [13.50%, 81/600, 53:16/06:34:36, 103.69s/i] 

Train loss: 0.7603778  Val loss: 0.8062283  Train acc: 0.7436708860759493  Val acc: 0.7187039764359352


┣███▎                ┫ [16.83%, 101/600, 01:06:25/06:34:32, 108.47s/i] 

Train loss: 0.61472404  Val loss: 0.6847767  Train acc: 0.7896440129449838  Val acc: 0.7555228276877761


┣████                ┫ [20.17%, 121/600, 01:20:04/06:37:04, 107.20s/i] 

Train loss: 0.55010855  Val loss: 0.60808086  Train acc: 0.7795527156549521  Val acc: 0.7776141384388807


┣████▋               ┫ [23.50%, 141/600, 01:33:31/06:37:56, 103.94s/i] 

Train loss: 0.30685553  Val loss: 0.5538574  Train acc: 0.9185667752442996  Val acc: 0.801178203240059


┣█████▎              ┫ [26.83%, 161/600, 01:46:53/06:38:18, 108.58s/i] 

Train loss: 0.3633994  Val loss: 0.5187481  Train acc: 0.8793650793650793  Val acc: 0.8144329896907216


┣██████              ┫ [30.17%, 181/600, 02:00:08/06:38:15, 104.08s/i] 

Train loss: 0.2546188  Val loss: 0.48341036  Train acc: 0.9228295819935691  Val acc: 0.8262150220913107


┣██████▋             ┫ [33.50%, 201/600, 02:14:28/06:41:24, 111.28s/i] 

Train loss: 0.17676342  Val loss: 0.4607604  Train acc: 0.9539473684210527  Val acc: 0.8321060382916053


┣███████▎            ┫ [36.83%, 221/600, 02:28:44/06:43:47, 111.40s/i] 

Train loss: 0.20662382  Val loss: 0.44135493  Train acc: 0.930921052631579  Val acc: 0.8365243004418262


┣████████            ┫ [40.17%, 241/600, 02:42:59/06:45:45, 111.53s/i] 

Train loss: 0.17817804  Val loss: 0.42596266  Train acc: 0.9636363636363636  Val acc: 0.8438880706921944


┣████████▋           ┫ [43.50%, 261/600, 02:57:08/06:47:11, 112.61s/i] 

Train loss: 0.12148997  Val loss: 0.41228577  Train acc: 0.9713375796178344  Val acc: 0.845360824742268


┣█████████▎          ┫ [46.83%, 281/600, 03:11:21/06:48:35, 113.79s/i] 

Train loss: 0.0972638  Val loss: 0.40038007  Train acc: 0.9746031746031746  Val acc: 0.8497790868924889


┣██████████          ┫ [50.17%, 301/600, 03:25:44/06:50:06, 123.12s/i] 

Train loss: 0.09496257  Val loss: 0.39308658  Train acc: 0.9713375796178344  Val acc: 0.8586156111929307


┣██████████▋         ┫ [53.50%, 321/600, 03:41:47/06:54:32, 121.80s/i] 

Train loss: 0.097255155  Val loss: 0.38299498  Train acc: 0.9837133550488599  Val acc: 0.865979381443299


┣███████████▎        ┫ [56.83%, 341/600, 03:57:13/06:57:24, 134.32s/i] 

Train loss: 0.05797205  Val loss: 0.3773708  Train acc: 0.990506329113924  Val acc: 0.8703976435935199


┣████████████        ┫ [60.17%, 361/600, 04:11:47/06:58:29, 116.22s/i] 

Train loss: 0.035601567  Val loss: 0.37687027  Train acc: 1.0  Val acc: 0.8718703976435935


┣████████████▋       ┫ [63.50%, 381/600, 04:26:22/06:59:29, 114.88s/i] 

Train loss: 0.058325697  Val loss: 0.37328833  Train acc: 0.9867549668874173  Val acc: 0.8733431516936672


┣█████████████▎      ┫ [66.83%, 401/600, 04:40:44/07:00:03, 115.43s/i] 

Train loss: 0.042593323  Val loss: 0.3651906  Train acc: 0.9902597402597403  Val acc: 0.8733431516936672


┣██████████████      ┫ [70.17%, 421/600, 04:55:25/07:01:01, 114.03s/i] 

Train loss: 0.03785879  Val loss: 0.3701025  Train acc: 0.9936708860759493  Val acc: 0.8777614138438881


┣██████████████▋     ┫ [73.50%, 441/600, 05:09:52/07:01:35, 114.34s/i] 

Train loss: 0.029305415  Val loss: 0.365034  Train acc: 0.9968051118210862  Val acc: 0.8807069219440353


┣███████████████▎    ┫ [76.83%, 461/600, 05:24:31/07:02:21, 111.21s/i] 

Train loss: 0.032589197  Val loss: 0.3685532  Train acc: 0.9937106918238994  Val acc: 0.8807069219440353


┣████████████████    ┫ [80.17%, 481/600, 05:37:57/07:01:34, 104.81s/i] 

Train loss: 0.054419357  Val loss: 0.36570343  Train acc: 0.987220447284345  Val acc: 0.8807069219440353


┣████████████████▋   ┫ [83.50%, 501/600, 05:51:18/07:00:43, 104.41s/i] 

Train loss: 0.017000398  Val loss: 0.36538532  Train acc: 0.996875  Val acc: 0.882179675994109


┣█████████████████▎  ┫ [86.83%, 521/600, 06:04:29/06:59:45, 103.31s/i] 

Train loss: 0.011174919  Val loss: 0.3647383  Train acc: 1.0  Val acc: 0.8807069219440353


┣██████████████████  ┫ [90.17%, 541/600, 06:17:43/06:58:55, 103.77s/i] 

Train loss: 0.0153678795  Val loss: 0.36292788  Train acc: 0.9968553459119497  Val acc: 0.8880706921944035


┣██████████████████▋ ┫ [93.50%, 561/600, 06:31:12/06:58:24, 106.95s/i] 

Train loss: 0.019952085  Val loss: 0.36271635  Train acc: 0.9938461538461538  Val acc: 0.8895434462444771


┣███████████████████▎┫ [96.83%, 581/600, 06:44:55/06:58:10, 105.84s/i] 

Train loss: 0.020592926  Val loss: 0.36113295  Train acc: 0.9968847352024922  Val acc: 0.8865979381443299


┣████████████████████┫ [100.00%, 600/600, 06:56:29/06:56:29, 35.15s/i] 

Train loss: 0.0066857697  Val loss: 0.36587235  Train acc: 1.0  Val acc: 0.8865979381443299
Training ended successfully, saving the results
Test loss: 1.0690082  Test acc: 0.7466216216216216


## Training of Gradient Surgery model with "Photo" target, PACS dataset

In [34]:
n=1
results=[]
best_acc=0
for it in progress(1:num_iter)
    train_batches = py"return_data_gs(train_iterator)"
    inputs1, targets1 = train_batches[1];
    inputs1 = permutedims(inputs1.detach().numpy(), (4, 3, 2, 1));
    targets1 = targets1.detach().numpy();
    
    inputs2, targets2 = train_batches[2];
    inputs2 = permutedims(inputs2.detach().numpy(), (4, 3, 2, 1));
    targets2 = targets2.detach().numpy();
    
    inputs3, targets3 = train_batches[3];
    inputs3 = permutedims(inputs3.detach().numpy(), (4, 3, 2, 1));
    targets3 = targets3.detach().numpy();
    
    D1 = @diff model(inputs1, targets1)
    D2 = @diff model(inputs2, targets2)
    D3 = @diff model(inputs3, targets3)

    for w in Knet.params(model)
        g1 = grad(D1, w)
        g2 = grad(D2, w)
        g3 = grad(D3, w)
        MaskGS = abs.(sign.(g1).+sign.(g2).+sign.(g3)).==3
        g = (g1.+g2.+g3).*MaskGS
        update!(w, g)
    end

    if n==1 
        println("First batch is trained successfully")
        best_acc = accuracy(model(inputs_val),targets_val)
    end

    if n%check_freq == 0
        acc = accuracy(model(inputs_val),targets_val)
        if acc > best_acc
            best_acc = acc
            save(string("best_model_gs_debugging_target_", target,".jld2"),"weights",model)
        end
        temp_loss = (model(inputs1, targets1) + model(inputs2, targets2) + model(inputs3, targets3))/3
        temp_acc = (accuracy(model(inputs1),targets1) + accuracy(model(inputs2),targets2) + accuracy(model(inputs3),targets3))/3
        result=(temp_loss, model(inputs_val, targets_val), temp_acc, acc)
        println("Train loss: ",result[1],"  Val loss: ",result[2],
                "  Train acc: ",result[3],"  Val acc: ",result[4])
        push!(results,result)
    end

    (n += 1) > num_iter && break
end 
println("Training ended successfully, saving the results")
save(string("results_gs_debugging_target_", target,".jld2"),"results",results)
best_model = load(string("best_model_gs_debugging_target_", target,".jld2"))["weights"]
tst_loss = best_model(inputs_test, targets_test)
tst_acc = accuracy(model(inputs_test),targets_test)
println("Test loss: ", tst_loss, "  Test acc: ", tst_acc)

┣                    ┫ [0.17%, 1/600, 00:00/00:00, 3419.77i/s] 

First batch is trained successfully


┣▋                   ┫ [3.50%, 21/600, 14:04/06:41:59, 105.62s/i] 

Train loss: 1.9722558  Val loss: 1.9904436  Train acc: 0.31257433684618147  Val acc: 0.30338733431516934


┣█▎                  ┫ [6.83%, 41/600, 27:41/06:45:00, 105.35s/i] 

Train loss: 1.6355942  Val loss: 1.6037585  Train acc: 0.38981198952072743  Val acc: 0.4005891016200295


┣██                  ┫ [10.17%, 61/600, 41:45/06:50:43, 110.90s/i] 

Train loss: 1.428254  Val loss: 1.3810295  Train acc: 0.502581369248036  Val acc: 0.48748159057437407


┣██▋                 ┫ [13.50%, 81/600, 55:35/06:51:41, 106.69s/i] 

Train loss: 1.2161494  Val loss: 1.1866721  Train acc: 0.5854747907075052  Val acc: 0.5670103092783505


┣███▎                ┫ [16.83%, 101/600, 01:09:18/06:51:40, 106.46s/i] 

Train loss: 1.0432643  Val loss: 1.0253648  Train acc: 0.6422318422318423  Val acc: 0.6200294550810015


┣████                ┫ [20.17%, 121/600, 01:23:02/06:51:42, 105.45s/i] 

Train loss: 0.95774037  Val loss: 0.89609283  Train acc: 0.6549237170596394  Val acc: 0.6642120765832106


┣████▋               ┫ [23.50%, 141/600, 01:36:38/06:51:11, 108.39s/i] 

Train loss: 0.66941786  Val loss: 0.79835045  Train acc: 0.7746369469773725  Val acc: 0.6995581737849779


┣█████▎              ┫ [26.83%, 161/600, 01:50:47/06:52:53, 108.69s/i] 

Train loss: 0.653483  Val loss: 0.72793037  Train acc: 0.7793795093795094  Val acc: 0.7187039764359352


┣██████              ┫ [30.17%, 181/600, 02:04:36/06:53:01, 107.63s/i] 

Train loss: 0.5961634  Val loss: 0.67178595  Train acc: 0.8131793398920691  Val acc: 0.7452135493372607


┣██████▋             ┫ [33.50%, 201/600, 02:18:27/06:53:18, 106.79s/i] 

Train loss: 0.4244356  Val loss: 0.6298921  Train acc: 0.8794711038421238  Val acc: 0.7584683357879234


┣███████▎            ┫ [36.83%, 221/600, 02:32:32/06:54:06, 106.09s/i] 

Train loss: 0.5306441  Val loss: 0.6005522  Train acc: 0.8264591227106051  Val acc: 0.7658321060382917


┣████████            ┫ [40.17%, 241/600, 02:46:40/06:54:56, 103.31s/i] 

Train loss: 0.45009986  Val loss: 0.5730234  Train acc: 0.8425289955576997  Val acc: 0.7893961708394698


┣████████▋           ┫ [43.50%, 261/600, 03:00:28/06:54:52, 108.20s/i] 

Train loss: 0.37870976  Val loss: 0.5484681  Train acc: 0.8723996129656507  Val acc: 0.7938144329896907


┣█████████▎          ┫ [46.83%, 281/600, 03:14:44/06:55:47, 109.39s/i] 

Train loss: 0.34000954  Val loss: 0.52982914  Train acc: 0.8998316498316498  Val acc: 0.801178203240059


┣██████████          ┫ [50.17%, 301/600, 03:29:02/06:56:41, 110.30s/i] 

Train loss: 0.31222025  Val loss: 0.51326793  Train acc: 0.8918210275508849  Val acc: 0.8070692194403535


┣██████████▋         ┫ [53.50%, 321/600, 03:43:18/06:57:23, 113.35s/i] 

Train loss: 0.38649476  Val loss: 0.4991447  Train acc: 0.8725180608675753  Val acc: 0.8144329896907216


┣███████████▎        ┫ [56.83%, 341/600, 03:57:56/06:58:40, 111.33s/i] 

Train loss: 0.27313218  Val loss: 0.48331332  Train acc: 0.9232425917587864  Val acc: 0.8203240058910162


┣████████████        ┫ [60.17%, 361/600, 04:12:36/06:59:50, 129.30s/i] 

Train loss: 0.24026811  Val loss: 0.4714051  Train acc: 0.9175274725274726  Val acc: 0.8247422680412371


┣████████████▋       ┫ [63.50%, 381/600, 04:28:24/07:02:40, 122.11s/i] 

Train loss: 0.29718092  Val loss: 0.46242207  Train acc: 0.9104390857940752  Val acc: 0.8306332842415317


┣█████████████▎      ┫ [66.83%, 401/600, 04:44:12/07:05:14, 119.32s/i] 

Train loss: 0.24419709  Val loss: 0.4520982  Train acc: 0.9087810337810338  Val acc: 0.833578792341679


┣██████████████      ┫ [70.17%, 421/600, 04:58:48/07:05:50, 110.91s/i] 

Train loss: 0.2636282  Val loss: 0.44736344  Train acc: 0.909761163032191  Val acc: 0.833578792341679


┣██████████████▋     ┫ [73.50%, 441/600, 05:13:04/07:05:57, 110.71s/i] 

Train loss: 0.21050085  Val loss: 0.44009012  Train acc: 0.9353615520282187  Val acc: 0.8438880706921944


┣███████████████▎    ┫ [76.83%, 461/600, 05:27:18/07:06:00, 111.18s/i] 

Train loss: 0.22624488  Val loss: 0.4292507  Train acc: 0.9147266313932981  Val acc: 0.8497790868924889


┣████████████████    ┫ [80.17%, 481/600, 05:41:58/07:06:34, 107.70s/i] 

Train loss: 0.2700124  Val loss: 0.42187068  Train acc: 0.9185828484100393  Val acc: 0.8483063328424153


┣████████████████▋   ┫ [83.50%, 501/600, 05:56:15/07:06:39, 107.67s/i] 

Train loss: 0.12935527  Val loss: 0.42072085  Train acc: 0.959119164887996  Val acc: 0.8541973490427098


┣█████████████████▎  ┫ [86.83%, 521/600, 06:10:02/07:06:08, 104.76s/i] 

Train loss: 0.17314486  Val loss: 0.418859  Train acc: 0.9513832957932337  Val acc: 0.8541973490427098


┣██████████████████  ┫ [90.17%, 541/600, 06:23:51/07:05:43, 104.36s/i] 

Train loss: 0.141451  Val loss: 0.41175207  Train acc: 0.9551067040213571  Val acc: 0.8571428571428571


┣██████████████████▋ ┫ [93.50%, 561/600, 06:38:10/07:05:51, 113.95s/i] 

Train loss: 0.1622795  Val loss: 0.41296008  Train acc: 0.9627977492055163  Val acc: 0.8497790868924889


┣███████████████████▎┫ [96.83%, 581/600, 06:52:35/07:06:04, 109.23s/i] 

Train loss: 0.17271666  Val loss: 0.4080725  Train acc: 0.9403947091997179  Val acc: 0.8541973490427098


┣████████████████████┫ [100.00%, 600/600, 07:05:24/07:05:24, 38.81s/i] 

Train loss: 0.09538663  Val loss: 0.4082594  Train acc: 0.9742039891572603  Val acc: 0.8586156111929307
Training ended successfully, saving the results
Test loss: 1.0408199  Test acc: 0.7466216216216216


## Training of Baseline model with "LabelMe" target, VLCS dataset

In [104]:
n=1
results=[]
best_acc=-1
for it in progress(1:num_iter)
    inputs_train, targets_train = py"return_data(train_iterator)"
    inputs_train = permutedims(inputs_train.detach().numpy(), (4, 3, 2, 1));
    targets_train = targets_train.detach().numpy();

    D = @diff model(inputs_train, targets_train);
    for w in Knet.params(model)
        g = grad(D, w)
        update!(w, g)
    end
    if n==1 
        println("First batch is trained successfully")
        best_acc = accuracy(model(inputs_val),targets_val)
    end
    if n%check_freq == 0
        acc = accuracy(model(inputs_val),targets_val)
        if acc > best_acc
            best_acc = acc
            save(string("best_model_b_debugging_target_", target,".jld2"),"weights",model)
        end
        result=(model(inputs_train, targets_train), model(inputs_val, targets_val), 
                    accuracy(model(inputs_train),targets_train), acc)
        println("Train loss: ",result[1],"  Val loss: ",result[2],
                "  Train acc: ",result[3],"  Val acc: ",result[4])
        push!(results,result)
    end

    (n += 1) > num_iter && break
end 
println("Training ended successfully, saving the results")
save(string("results_b_debugging_target_", target,".jld2"),"results",results)
best_model = load(string("best_model_b_debugging_target_", target,".jld2"))["weights"]
tst_loss = best_model(inputs_test, targets_test)
tst_acc = accuracy(model(inputs_test),targets_test)
println("Test loss: ", tst_loss, "  Test acc: ", tst_acc)

┣                    ┫ [0.17%, 1/600, 00:00/00:00, 2159.05i/s] 

First batch is trained successfully


┣▋                   ┫ [3.50%, 21/600, 16:07/07:40:33, 120.40s/i] 

Train loss: 0.83804864  Val loss: 1.1055499  Train acc: 0.6944444444444444  Val acc: 0.58


┣█▎                  ┫ [6.83%, 41/600, 33:24/08:08:45, 124.25s/i] 

Train loss: 0.72225666  Val loss: 0.88412446  Train acc: 0.7277936962750716  Val acc: 0.648


┣██                  ┫ [10.17%, 61/600, 52:05/08:32:21, 128.28s/i] 

Train loss: 0.6719665  Val loss: 0.78634214  Train acc: 0.7293447293447294  Val acc: 0.6813333333333333


┣██▋                 ┫ [13.50%, 81/600, 01:09:16/08:33:03, 123.66s/i] 

Train loss: 0.49204266  Val loss: 0.7389432  Train acc: 0.7971014492753623  Val acc: 0.708


┣███▎                ┫ [16.83%, 101/600, 01:28:04/08:43:11, 123.05s/i] 

Train loss: 0.5001099  Val loss: 0.70877874  Train acc: 0.7906976744186046  Val acc: 0.708


┣████                ┫ [20.17%, 121/600, 01:45:13/08:41:45, 130.20s/i] 

Train loss: 0.47882128  Val loss: 0.6836695  Train acc: 0.8104956268221575  Val acc: 0.7306666666666667


┣████▋               ┫ [23.50%, 141/600, 02:02:19/08:40:29, 128.51s/i] 

Train loss: 0.4438761  Val loss: 0.669263  Train acc: 0.8260869565217391  Val acc: 0.7373333333333333


┣█████▎              ┫ [26.83%, 161/600, 02:19:10/08:38:37, 136.30s/i] 

Train loss: 0.42836812  Val loss: 0.65840375  Train acc: 0.8169014084507042  Val acc: 0.7413333333333333


┣██████              ┫ [30.17%, 181/600, 02:36:36/08:39:06, 125.10s/i] 

Train loss: 0.3968979  Val loss: 0.64787555  Train acc: 0.8647887323943662  Val acc: 0.74


┣██████▋             ┫ [33.50%, 201/600, 02:53:48/08:38:49, 128.93s/i] 

Train loss: 0.4003579  Val loss: 0.6381761  Train acc: 0.861764705882353  Val acc: 0.7466666666666667


┣███████▎            ┫ [36.83%, 221/600, 03:11:15/08:39:14, 128.90s/i] 

Train loss: 0.36852762  Val loss: 0.6326334  Train acc: 0.8636363636363636  Val acc: 0.7413333333333333


┣████████            ┫ [40.17%, 241/600, 03:29:03/08:40:27, 124.19s/i] 

Train loss: 0.3273816  Val loss: 0.6272776  Train acc: 0.8876080691642652  Val acc: 0.7453333333333333


┣████████▋           ┫ [43.50%, 261/600, 03:46:29/08:40:40, 125.73s/i] 

Train loss: 0.3361952  Val loss: 0.62647617  Train acc: 0.8746355685131195  Val acc: 0.7466666666666667


┣█████████▎          ┫ [46.83%, 281/600, 04:02:55/08:38:41, 115.84s/i] 

Train loss: 0.32101232  Val loss: 0.62343454  Train acc: 0.8959537572254336  Val acc: 0.7426666666666667


┣██████████          ┫ [50.17%, 301/600, 04:18:43/08:35:43, 111.53s/i] 

Train loss: 0.31181154  Val loss: 0.62272584  Train acc: 0.8823529411764706  Val acc: 0.74


┣██████████▋         ┫ [53.50%, 321/600, 04:34:21/08:32:48, 114.03s/i] 

Train loss: 0.26967826  Val loss: 0.6209051  Train acc: 0.8985915492957747  Val acc: 0.7493333333333333


┣███████████▎        ┫ [56.83%, 341/600, 04:49:39/08:29:40, 112.29s/i] 

Train loss: 0.26565212  Val loss: 0.6209872  Train acc: 0.9221902017291066  Val acc: 0.7506666666666667


┣████████████        ┫ [60.17%, 361/600, 05:05:06/08:27:06, 114.03s/i] 

Train loss: 0.27465838  Val loss: 0.6202155  Train acc: 0.9037900874635568  Val acc: 0.748


┣████████████▋       ┫ [63.50%, 381/600, 05:20:23/08:24:32, 110.24s/i] 

Train loss: 0.23845221  Val loss: 0.61953026  Train acc: 0.9285714285714286  Val acc: 0.748


┣█████████████▎      ┫ [66.83%, 401/600, 05:35:41/08:22:17, 112.18s/i] 

Train loss: 0.19337833  Val loss: 0.6216427  Train acc: 0.9369627507163324  Val acc: 0.7546666666666667


┣██████████████      ┫ [70.17%, 421/600, 05:51:10/08:20:29, 114.00s/i] 

Train loss: 0.20134294  Val loss: 0.6216548  Train acc: 0.9546742209631728  Val acc: 0.752


┣██████████████▋     ┫ [73.50%, 441/600, 06:06:41/08:18:54, 113.19s/i] 

Train loss: 0.21854742  Val loss: 0.6286312  Train acc: 0.9287749287749287  Val acc: 0.7586666666666667


┣███████████████▎    ┫ [76.83%, 461/600, 06:22:09/08:17:22, 114.13s/i] 

Train loss: 0.17848554  Val loss: 0.6334652  Train acc: 0.9393063583815029  Val acc: 0.756


┣████████████████    ┫ [80.17%, 481/600, 06:37:38/08:16:01, 111.90s/i] 

Train loss: 0.17841789  Val loss: 0.6407586  Train acc: 0.9603399433427762  Val acc: 0.7533333333333333


┣████████████████▋   ┫ [83.50%, 501/600, 06:53:09/08:14:47, 110.63s/i] 

Train loss: 0.13958111  Val loss: 0.6449422  Train acc: 0.9662921348314607  Val acc: 0.7586666666666667


┣█████████████████▎  ┫ [86.83%, 521/600, 07:08:30/08:13:28, 112.80s/i] 

Train loss: 0.12981798  Val loss: 0.6437071  Train acc: 0.9658119658119658  Val acc: 0.7626666666666667


┣██████████████████  ┫ [90.17%, 541/600, 07:24:03/08:12:29, 111.50s/i] 

Train loss: 0.1185142  Val loss: 0.64963555  Train acc: 0.9744318181818182  Val acc: 0.7666666666666667


┣██████████████████▋ ┫ [93.50%, 561/600, 07:39:30/08:11:27, 116.54s/i] 

Train loss: 0.10668232  Val loss: 0.655386  Train acc: 0.9761904761904762  Val acc: 0.7666666666666667


┣███████████████████▎┫ [96.83%, 581/600, 07:54:58/08:10:30, 109.04s/i] 

Train loss: 0.09134358  Val loss: 0.65654176  Train acc: 0.9858356940509915  Val acc: 0.764


┣████████████████████┫ [100.00%, 600/600, 08:08:32/08:08:32, 44.04s/i] 

Train loss: 0.08083718  Val loss: 0.6680542  Train acc: 0.9888579387186629  Val acc: 0.764
Training ended successfully, saving the results
Test loss: 1.320581  Test acc: 0.563953488372093


## Training of Gradient Surgery model with "LabelMe" target, VLCS dataset

In [9]:
n=1
results=[]
best_acc=0
for it in progress(1:num_iter)
    train_batches = py"return_data_gs(train_iterator)"
    inputs1, targets1 = train_batches[1];
    inputs1 = permutedims(inputs1.detach().numpy(), (4, 3, 2, 1));
    targets1 = targets1.detach().numpy();
    
    inputs2, targets2 = train_batches[2];
    inputs2 = permutedims(inputs2.detach().numpy(), (4, 3, 2, 1));
    targets2 = targets2.detach().numpy();
    
    inputs3, targets3 = train_batches[3];
    inputs3 = permutedims(inputs3.detach().numpy(), (4, 3, 2, 1));
    targets3 = targets3.detach().numpy();
    
    D1 = @diff model(inputs1, targets1)
    D2 = @diff model(inputs2, targets2)
    D3 = @diff model(inputs3, targets3)

    for w in Knet.params(model)
        g1 = grad(D1, w)
        g2 = grad(D2, w)
        g3 = grad(D3, w)
        MaskGS = abs.(sign.(g1).+sign.(g2).+sign.(g3)).==3
        g = (g1.+g2.+g3).*MaskGS
        update!(w, g)
    end

    if n==1 
        println("First batch is trained successfully")
        best_acc = accuracy(model(inputs_val),targets_val)
    end

    if n%check_freq == 0
        acc = accuracy(model(inputs_val),targets_val)
        if acc > best_acc
            best_acc = acc
            save(string("best_model_gs_debugging_target_", target,".jld2"),"weights",model)
        end
        temp_loss = (model(inputs1, targets1) + model(inputs2, targets2) + model(inputs3, targets3))/3
        temp_acc = (accuracy(model(inputs1),targets1) + accuracy(model(inputs2),targets2) + accuracy(model(inputs3),targets3))/3
        result=(temp_loss, model(inputs_val, targets_val), temp_acc, acc)
        println("Train loss: ",result[1],"  Val loss: ",result[2],
                "  Train acc: ",result[3],"  Val acc: ",result[4])
        push!(results,result)
    end

    (n += 1) > num_iter && break
end 
println("Training ended successfully, saving the results")
save(string("results_gs_debugging_target_", target,".jld2"),"results",results)
best_model = load(string("best_model_gs_debugging_target_", target,".jld2"))["weights"]
tst_loss = best_model(inputs_test, targets_test)
tst_acc = accuracy(model(inputs_test),targets_test)
println("Test loss: ", tst_loss, "  Test acc: ", tst_acc)

┣                    ┫ [0.17%, 1/600, 00:00/00:01, 875.30i/s] 

First batch is trained successfully


┣▋                   ┫ [3.33%, 20/600, 12:54/06:26:58, 38.35s/i] 

Train loss: 1.0689274

┣▋                   ┫ [3.50%, 21/600, 14:48/07:02:58, 114.28s/i] 

  Val loss: 1.292176  Train acc: 0.6066023867759563  Val acc: 0.528


┣█▎                  ┫ [6.83%, 41/600, 29:26/07:10:42, 107.31s/i] 

Train loss: 0.90264845  Val loss: 1.0668844  Train acc: 0.6898819900191414  Val acc: 0.6026666666666667


┣██                  ┫ [10.17%, 61/600, 44:07/07:13:59, 109.18s/i] 

Train loss: 0.81405765  Val loss: 0.9516343  Train acc: 0.690233155315207  Val acc: 0.6373333333333333


┣██▋                 ┫ [13.50%, 81/600, 58:24/07:12:35, 106.05s/i] 

Train loss: 0.5954676  Val loss: 0.8784392  Train acc: 0.7826324260581337  Val acc: 0.6573333333333333


┣███▎                ┫ [16.83%, 101/600, 01:12:31/07:10:46, 108.85s/i] 

Train loss: 0.62125415  Val loss: 0.84553033  Train acc: 0.7563920454545454  Val acc: 0.68


┣████                ┫ [20.17%, 121/600, 01:26:37/07:09:30, 106.00s/i] 

Train loss: 0.6385338  Val loss: 0.8126002  Train acc: 0.741636352935048  Val acc: 0.6933333333333334


┣████▋               ┫ [23.50%, 141/600, 01:41:03/07:09:59, 108.16s/i] 

Train loss: 0.58333415  Val loss: 0.793615  Train acc: 0.7702482702482701  Val acc: 0.7


┣█████▎              ┫ [26.83%, 161/600, 01:55:00/07:08:34, 106.78s/i] 

Train loss: 0.57171017  Val loss: 0.7824782  Train acc: 0.7586190734778842  Val acc: 0.7066666666666667


┣██████              ┫ [30.17%, 181/600, 02:09:28/07:09:11, 111.75s/i] 

Train loss: 0.5085011  Val loss: 0.77342147  Train acc: 0.7945256842894639  Val acc: 0.7066666666666667


┣██████▋             ┫ [33.50%, 201/600, 02:23:49/07:09:19, 108.09s/i] 

Train loss: 0.5675115  Val loss: 0.76907164  Train acc: 0.7687664041994751  Val acc: 0.712


┣███████▎            ┫ [36.83%, 221/600, 02:37:47/07:08:23, 102.78s/i] 

Train loss: 0.5412385  Val loss: 0.768037  Train acc: 0.7920626301022717  Val acc: 0.7106666666666667


┣████████            ┫ [40.17%, 241/600, 02:52:02/07:08:18, 105.62s/i] 

Train loss: 0.48991552  Val loss: 0.7652309  Train acc: 0.8019021739130435  Val acc: 0.7146666666666667


┣████████▋           ┫ [43.50%, 261/600, 03:06:05/07:07:47, 104.11s/i] 

Train loss: 0.52180123  Val loss: 0.76404434  Train acc: 0.8137686917575669  Val acc: 0.7106666666666667


┣█████████▎          ┫ [46.83%, 281/600, 03:20:07/07:07:18, 108.50s/i] 

Train loss: 0.4920195  Val loss: 0.76658803  Train acc: 0.7967342342342342  Val acc: 0.712


┣██████████          ┫ [50.17%, 301/600, 03:34:55/07:08:25, 102.65s/i] 

Train loss: 0.5915499  Val loss: 0.76783496  Train acc: 0.763637653401433  Val acc: 0.7146666666666667


┣██████████▋         ┫ [53.50%, 321/600, 03:49:28/07:08:54, 104.02s/i] 

Train loss: 0.46011785  Val loss: 0.77430165  Train acc: 0.8317708333333332  Val acc: 0.7133333333333334


┣███████████▎        ┫ [56.83%, 341/600, 04:03:33/07:08:31, 106.22s/i] 

Train loss: 0.5104224  Val loss: 0.7669391  Train acc: 0.7961440058479532  Val acc: 0.7226666666666667


┣████████████        ┫ [60.17%, 361/600, 04:18:03/07:08:54, 103.65s/i] 

Train loss: 0.5313437  Val loss: 0.779206  Train acc: 0.7880473163841808  Val acc: 0.7186666666666667


┣████████████▋       ┫ [63.50%, 381/600, 04:32:47/07:09:35, 106.38s/i] 

Train loss: 0.52095217  Val loss: 0.7919084  Train acc: 0.808659052753541  Val acc: 0.7213333333333334


┣█████████████▎      ┫ [66.83%, 401/600, 04:47:14/07:09:47, 107.00s/i] 

Train loss: 0.43735752  Val loss: 0.7926767  Train acc: 0.8238737535612536  Val acc: 0.7226666666666667


┣██████████████      ┫ [70.17%, 421/600, 05:01:58/07:10:21, 129.38s/i] 

Train loss: 0.38162497  Val loss: 0.797603  Train acc: 0.8639150073746312  Val acc: 0.7146666666666667


┣██████████████▋     ┫ [73.50%, 441/600, 05:17:29/07:11:57, 112.23s/i] 

Train loss: 0.48797527  Val loss: 0.81258446  Train acc: 0.826233084045584  Val acc: 0.708


┣███████████████▎    ┫ [76.83%, 461/600, 05:32:02/07:12:09, 109.88s/i] 

Train loss: 0.4360268  Val loss: 0.8143679  Train acc: 0.8233989197530865  Val acc: 0.7213333333333334


┣████████████████    ┫ [80.17%, 481/600, 05:47:25/07:13:22, 125.18s/i] 

Train loss: 0.4721366  Val loss: 0.8273966  Train acc: 0.8363809895663868  Val acc: 0.7266666666666667


┣████████████████▋   ┫ [83.50%, 501/600, 06:03:31/07:15:21, 110.26s/i] 

Train loss: 0.41534507  Val loss: 0.83391774  Train acc: 0.8422991243741591  Val acc: 0.7213333333333334


┣█████████████████▎  ┫ [86.83%, 521/600, 06:17:57/07:15:15, 111.96s/i] 

Train loss: 0.37047148  Val loss: 0.8451346  Train acc: 0.8557830459770116  Val acc: 0.7266666666666667


┣██████████████████  ┫ [90.17%, 541/600, 06:34:15/07:17:15, 144.69s/i] 

Train loss: 0.39918017  Val loss: 0.84807765  Train acc: 0.8471623563218391  Val acc: 0.7293333333333333


┣██████████████████▋ ┫ [93.50%, 561/600, 06:49:50/07:18:19, 122.05s/i] 

Train loss: 0.35507107  Val loss: 0.8567928  Train acc: 0.8899147727272728  Val acc: 0.7213333333333334


┣███████████████████▎┫ [96.83%, 581/600, 07:06:19/07:20:16, 113.80s/i] 

Train loss: 0.38424197  Val loss: 0.87818116  Train acc: 0.8453890297611104  Val acc: 0.732


┣████████████████████┫ [100.00%, 600/600, 07:21:58/07:21:58, 48.99s/i] 

Train loss: 0.3542315  Val loss: 0.89532477  Train acc: 0.8673327598756039  Val acc: 0.7253333333333334
Training ended successfully, saving the results
Test loss: 1.5912584  Test acc: 0.5465116279069767


## Training of Baseline model with "Caltech101" target, VLCS dataset

In [16]:
n=1
results=[]
best_acc=-1
for it in progress(1:num_iter)
    inputs_train, targets_train = py"return_data(train_iterator)"
    inputs_train = permutedims(inputs_train.detach().numpy(), (4, 3, 2, 1));
    targets_train = targets_train.detach().numpy();

    D = @diff model(inputs_train, targets_train);
    for w in Knet.params(model)
        g = grad(D, w)
        update!(w, g)
    end
    if n==1 
        println("First batch is trained successfully")
        best_acc = accuracy(model(inputs_val),targets_val)
    end
    if n%check_freq == 0
        acc = accuracy(model(inputs_val),targets_val)
        if acc > best_acc
            best_acc = acc
            save(string("best_model_b_debugging_target_", target,".jld2"),"weights",model)
        end
        result=(model(inputs_train, targets_train), model(inputs_val, targets_val), 
                    accuracy(model(inputs_train),targets_train), acc)
        println("Train loss: ",result[1],"  Val loss: ",result[2],
                "  Train acc: ",result[3],"  Val acc: ",result[4])
        push!(results,result)
    end

    (n += 1) > num_iter && break
end 
println("Training ended successfully, saving the results")
save(string("results_b_debugging_target_", target,".jld2"),"results",results)
best_model = load(string("best_model_b_debugging_target_", target,".jld2"))["weights"]
tst_loss = best_model(inputs_test, targets_test)
tst_acc = accuracy(model(inputs_test),targets_test)
println("Test loss: ", tst_loss, "  Test acc: ", tst_acc)

┣                    ┫ [0.17%, 1/600, 00:00/00:00, 41811.26i/s] 

First batch is trained successfully


┣▋                   ┫ [3.50%, 21/600, 17:20/08:15:20, 125.13s/i] 

Train loss: 1.100248  Val loss: 1.1851102  Train acc: 0.5675675675675675  Val acc: 0.5202247191011236


┣█▎                  ┫ [6.83%, 41/600, 35:30/08:39:28, 148.24s/i] 

Train loss: 0.8871587  Val loss: 0.9747914  Train acc: 0.6366120218579235  Val acc: 0.5853932584269663


┣██                  ┫ [10.17%, 61/600, 52:33/08:36:53, 124.90s/i] 

Train loss: 0.7889072  Val loss: 0.8905041  Train acc: 0.6557377049180327  Val acc: 0.6292134831460674


┣██▋                 ┫ [13.50%, 81/600, 01:11:23/08:48:45, 122.26s/i] 

Train loss: 0.75684863  Val loss: 0.84599566  Train acc: 0.6731301939058172  Val acc: 0.6494382022471911


┣███▎                ┫ [16.83%, 101/600, 01:25:56/08:30:28, 116.90s/i] 

Train loss: 0.643907  Val loss: 0.8145684  Train acc: 0.7444444444444445  Val acc: 0.6640449438202247


┣████                ┫ [20.17%, 121/600, 01:40:36/08:18:49, 122.11s/i] 

Train loss: 0.66162646  Val loss: 0.7983411  Train acc: 0.7446236559139785  Val acc: 0.6651685393258427


┣████▋               ┫ [23.50%, 141/600, 01:55:44/08:12:31, 117.80s/i] 

Train loss: 0.6509478  Val loss: 0.78106457  Train acc: 0.7417582417582418  Val acc: 0.6786516853932584


┣█████▎              ┫ [26.83%, 161/600, 02:10:09/08:05:02, 115.05s/i] 

Train loss: 0.6022612  Val loss: 0.77166945  Train acc: 0.7967032967032966  Val acc: 0.6719101123595506


┣██████              ┫ [30.17%, 181/600, 02:24:33/07:59:11, 112.56s/i] 

Train loss: 0.54432464  Val loss: 0.7615011  Train acc: 0.8102981029810298  Val acc: 0.6797752808988764


┣██████▋             ┫ [33.50%, 201/600, 02:39:20/07:55:38, 112.72s/i] 

Train loss: 0.56665725  Val loss: 0.7574805  Train acc: 0.7866666666666666  Val acc: 0.6921348314606741


┣███████▎            ┫ [36.83%, 221/600, 02:53:39/07:51:27, 117.46s/i] 

Train loss: 0.49691352  Val loss: 0.749475  Train acc: 0.8135135135135135  Val acc: 0.6943820224719102


┣████████            ┫ [40.17%, 241/600, 03:07:58/07:47:59, 121.23s/i] 

Train loss: 0.54275876  Val loss: 0.743401  Train acc: 0.8010752688172043  Val acc: 0.6966292134831461


┣████████▋           ┫ [43.50%, 261/600, 03:22:33/07:45:37, 112.99s/i] 

Train loss: 0.47579375  Val loss: 0.74172664  Train acc: 0.8191780821917808  Val acc: 0.7


┣█████████▎          ┫ [46.83%, 281/600, 03:37:02/07:43:26, 113.84s/i] 

Train loss: 0.45688865  Val loss: 0.7402894  Train acc: 0.8246575342465754  Val acc: 0.701123595505618


┣██████████          ┫ [50.17%, 301/600, 03:52:09/07:42:45, 123.66s/i] 

Train loss: 0.41318592  Val loss: 0.73886305  Train acc: 0.8602739726027397  Val acc: 0.7067415730337079


┣██████████▋         ┫ [53.50%, 321/600, 04:06:39/07:41:01, 114.07s/i] 

Train loss: 0.4347422  Val loss: 0.7393157  Train acc: 0.8438356164383561  Val acc: 0.6943820224719102


┣███████████▎        ┫ [56.83%, 341/600, 04:21:05/07:39:23, 114.72s/i] 

Train loss: 0.42199406  Val loss: 0.74156946  Train acc: 0.8694444444444445  Val acc: 0.6932584269662921


┣████████████        ┫ [60.17%, 361/600, 04:35:32/07:37:57, 113.37s/i] 

Train loss: 0.39448032  Val loss: 0.74423116  Train acc: 0.8528610354223434  Val acc: 0.7033707865168539


┣████████████▋       ┫ [63.50%, 381/600, 04:50:06/07:36:50, 117.92s/i] 

Train loss: 0.4189935  Val loss: 0.74189633  Train acc: 0.8743169398907104  Val acc: 0.6966292134831461


┣█████████████▎      ┫ [66.83%, 401/600, 05:04:36/07:35:45, 118.12s/i] 

Train loss: 0.34758076  Val loss: 0.74519086  Train acc: 0.8879781420765027  Val acc: 0.7044943820224719


┣██████████████      ┫ [70.17%, 421/600, 05:19:15/07:34:59, 113.10s/i] 

Train loss: 0.34890068  Val loss: 0.7463264  Train acc: 0.8732782369146006  Val acc: 0.7033707865168539


┣██████████████▋     ┫ [73.50%, 441/600, 05:33:29/07:33:43, 112.29s/i] 

Train loss: 0.33046404  Val loss: 0.7544022  Train acc: 0.894878706199461  Val acc: 0.701123595505618


┣███████████████▎    ┫ [76.83%, 461/600, 05:47:49/07:32:42, 114.07s/i] 

Train loss: 0.30123672  Val loss: 0.7557  Train acc: 0.9178082191780822  Val acc: 0.6887640449438203


┣████████████████    ┫ [80.17%, 481/600, 06:02:17/07:31:54, 112.81s/i] 

Train loss: 0.26533648  Val loss: 0.7630548  Train acc: 0.9368131868131868  Val acc: 0.702247191011236


┣████████████████▋   ┫ [83.50%, 501/600, 06:16:37/07:31:02, 115.55s/i] 

Train loss: 0.25200474  Val loss: 0.76673293  Train acc: 0.943089430894309  Val acc: 0.6955056179775281


┣█████████████████▎  ┫ [86.83%, 521/600, 06:32:24/07:31:54, 118.90s/i] 

Train loss: 0.23819013  Val loss: 0.77519506  Train acc: 0.9371584699453552  Val acc: 0.7


┣██████████████████  ┫ [90.17%, 541/600, 06:46:43/07:31:05, 114.59s/i] 

Train loss: 0.20323408  Val loss: 0.7776526  Train acc: 0.9614325068870524  Val acc: 0.6932584269662921


┣██████████████████▋ ┫ [93.50%, 561/600, 07:01:43/07:31:02, 114.10s/i] 

Train loss: 0.19736783  Val loss: 0.7847305  Train acc: 0.953168044077135  Val acc: 0.7056179775280899


┣███████████████████▎┫ [96.83%, 581/600, 07:17:52/07:32:11, 119.00s/i] 

Train loss: 0.1836528  Val loss: 0.8009943  Train acc: 0.972972972972973  Val acc: 0.6955056179775281


┣████████████████████┫ [100.00%, 600/600, 07:32:06/07:32:06, 44.22s/i] 

Train loss: 0.16312927  Val loss: 0.8030285  Train acc: 0.9726775956284153  Val acc: 0.6876404494382022
Training ended successfully, saving the results
Test loss: 0.3394837  Test acc: 0.9110169491525424


## Training of Gradient Surgery model with "Caltech101" target, VLCS dataset

In [23]:
n=1
results=[]
best_acc=0
for it in progress(1:num_iter)
    train_batches = py"return_data_gs(train_iterator)"
    inputs1, targets1 = train_batches[1];
    inputs1 = permutedims(inputs1.detach().numpy(), (4, 3, 2, 1));
    targets1 = targets1.detach().numpy();
    
    inputs2, targets2 = train_batches[2];
    inputs2 = permutedims(inputs2.detach().numpy(), (4, 3, 2, 1));
    targets2 = targets2.detach().numpy();
    
    inputs3, targets3 = train_batches[3];
    inputs3 = permutedims(inputs3.detach().numpy(), (4, 3, 2, 1));
    targets3 = targets3.detach().numpy();
    
    D1 = @diff model(inputs1, targets1)
    D2 = @diff model(inputs2, targets2)
    D3 = @diff model(inputs3, targets3)

    for w in Knet.params(model)
        g1 = grad(D1, w)
        g2 = grad(D2, w)
        g3 = grad(D3, w)
        MaskGS = abs.(sign.(g1).+sign.(g2).+sign.(g3)).==3
        g = (g1.+g2.+g3).*MaskGS
        update!(w, g)
    end

    if n==1 
        println("First batch is trained successfully")
        best_acc = accuracy(model(inputs_val),targets_val)
    end

    if n%check_freq == 0
        acc = accuracy(model(inputs_val),targets_val)
        if acc > best_acc
            best_acc = acc
            save(string("best_model_gs_debugging_target_", target,".jld2"),"weights",model)
        end
        temp_loss = (model(inputs1, targets1) + model(inputs2, targets2) + model(inputs3, targets3))/3
        temp_acc = (accuracy(model(inputs1),targets1) + accuracy(model(inputs2),targets2) + accuracy(model(inputs3),targets3))/3
        result=(temp_loss, model(inputs_val, targets_val), temp_acc, acc)
        println("Train loss: ",result[1],"  Val loss: ",result[2],
                "  Train acc: ",result[3],"  Val acc: ",result[4])
        push!(results,result)
    end

    (n += 1) > num_iter && break
end 
println("Training ended successfully, saving the results")
save(string("results_gs_debugging_target_", target,".jld2"),"results",results)
best_model = load(string("best_model_gs_debugging_target_", target,".jld2"))["weights"]
tst_loss = best_model(inputs_test, targets_test)
tst_acc = accuracy(model(inputs_test),targets_test)
println("Test loss: ", tst_loss, "  Test acc: ", tst_acc)

┣                    ┫ [0.17%, 1/600, 00:00/00:00, 49485.35i/s] 

First batch is trained successfully


┣▋                   ┫ [3.50%, 21/600, 16:45/07:58:42, 122.16s/i] 

Train loss: 1.2514615  Val loss: 1.3190125  Train acc: 0.5023298369755063  Val acc: 0.48314606741573035


┣█▎                  ┫ [6.83%, 41/600, 32:51/08:00:40, 122.12s/i] 

Train loss: 1.024402  Val loss: 1.1070992  Train acc: 0.5762955476022151  Val acc: 0.549438202247191


┣██                  ┫ [10.17%, 61/600, 49:10/08:03:36, 133.98s/i] 

Train loss: 0.9403855  Val loss: 1.007746  Train acc: 0.5859909516920482  Val acc: 0.5786516853932584


┣██▋                 ┫ [13.50%, 81/600, 01:04:59/08:01:19, 114.94s/i] 

Train loss: 0.87312573  Val loss: 0.9490521  Train acc: 0.6302914862914863  Val acc: 0.6078651685393258


┣███▎                ┫ [16.83%, 101/600, 01:20:58/08:00:58, 117.04s/i] 

Train loss: 0.7654834  Val loss: 0.9068308  Train acc: 0.691762499559066  Val acc: 0.6269662921348315


┣████                ┫ [20.17%, 121/600, 01:35:47/07:54:59, 113.02s/i] 

Train loss: 0.73853284  Val loss: 0.8750912  Train acc: 0.7096873808241919  Val acc: 0.647191011235955


┣████▋               ┫ [23.50%, 141/600, 01:50:38/07:50:45, 114.66s/i] 

Train loss: 0.76730824  Val loss: 0.85057056  Train acc: 0.686953647744442  Val acc: 0.6550561797752809


┣█████▎              ┫ [26.83%, 161/600, 02:06:59/07:53:12, 120.86s/i] 

Train loss: 0.71986705  Val loss: 0.8335488  Train acc: 0.7187610544523587  Val acc: 0.6617977528089888


┣██████              ┫ [30.17%, 181/600, 02:22:08/07:51:11, 124.24s/i] 

Train loss: 0.6531771  Val loss: 0.8198903  Train acc: 0.724040150888797  Val acc: 0.6696629213483146


┣██████▋             ┫ [33.50%, 201/600, 02:40:46/07:59:53, 137.05s/i] 

Train loss: 0.68129086  Val loss: 0.8104791  Train acc: 0.7251035769289169  Val acc: 0.6786516853932584


┣███████▎            ┫ [36.83%, 221/600, 02:55:31/07:56:32, 116.32s/i] 

Train loss: 0.6080225  Val loss: 0.8004067  Train acc: 0.761698082010582  Val acc: 0.6797752808988764


┣████████            ┫ [40.17%, 241/600, 03:10:30/07:54:17, 122.13s/i] 

Train loss: 0.66229206  Val loss: 0.79343003  Train acc: 0.7479542998301683  Val acc: 0.6786516853932584


┣████████▋           ┫ [43.50%, 261/600, 03:26:05/07:53:46, 119.06s/i] 

Train loss: 0.6125135  Val loss: 0.7865047  Train acc: 0.7530666096134646  Val acc: 0.6876404494382022


┣█████████▎          ┫ [46.83%, 281/600, 03:41:19/07:52:33, 118.46s/i] 

Train loss: 0.6016073  Val loss: 0.783109  Train acc: 0.7543786873156342  Val acc: 0.6910112359550562


┣██████████          ┫ [50.17%, 301/600, 03:56:42/07:51:50, 129.51s/i] 

Train loss: 0.56659013  Val loss: 0.7787506  Train acc: 0.8040165180837019  Val acc: 0.6932584269662921


┣██████████▋         ┫ [53.50%, 321/600, 04:12:04/07:51:09, 122.76s/i] 

Train loss: 0.596385  Val loss: 0.77853274  Train acc: 0.754343620738839  Val acc: 0.698876404494382


┣███████████▎        ┫ [56.83%, 341/600, 04:27:05/07:49:56, 114.99s/i] 

Train loss: 0.58737403  Val loss: 0.7757251  Train acc: 0.7820684523809524  Val acc: 0.6932584269662921


┣████████████        ┫ [60.17%, 361/600, 04:42:10/07:48:59, 117.88s/i] 

Train loss: 0.5783918  Val loss: 0.77272594  Train acc: 0.7681643272851764  Val acc: 0.6910112359550562


┣████████████▋       ┫ [63.50%, 381/600, 04:57:06/07:47:52, 115.80s/i] 

Train loss: 0.6116985  Val loss: 0.7707185  Train acc: 0.7694126106194691  Val acc: 0.6932584269662921


┣█████████████▎      ┫ [66.83%, 401/600, 05:12:07/07:47:00, 115.02s/i] 

Train loss: 0.5242208  Val loss: 0.7672685  Train acc: 0.7848480636011294  Val acc: 0.6921348314606741


┣██████████████      ┫ [70.17%, 421/600, 05:27:21/07:46:32, 119.84s/i] 

Train loss: 0.5394117  Val loss: 0.7655895  Train acc: 0.792154806015693  Val acc: 0.6955056179775281


┣██████████████▋     ┫ [73.50%, 441/600, 05:42:26/07:45:54, 121.41s/i] 

Train loss: 0.5746197  Val loss: 0.76641023  Train acc: 0.7895083029875503  Val acc: 0.702247191011236


┣███████████████▎    ┫ [76.83%, 461/600, 05:57:31/07:45:19, 118.52s/i] 

Train loss: 0.521531  Val loss: 0.76514864  Train acc: 0.8162522923034176  Val acc: 0.7


┣████████████████    ┫ [80.17%, 481/600, 06:13:09/07:45:28, 116.62s/i] 

Train loss: 0.47811934  Val loss: 0.7648253  Train acc: 0.8035714285714285  Val acc: 0.702247191011236


┣████████████████▋   ┫ [83.50%, 501/600, 06:28:00/07:44:40, 117.37s/i] 

Train loss: 0.48281512  Val loss: 0.76746184  Train acc: 0.8314502542820407  Val acc: 0.702247191011236


┣█████████████████▎  ┫ [86.83%, 521/600, 06:43:39/07:44:51, 121.57s/i] 

Train loss: 0.47612324  Val loss: 0.7631234  Train acc: 0.836830320516075  Val acc: 0.7078651685393258


┣██████████████████  ┫ [90.17%, 541/600, 06:58:48/07:44:29, 121.66s/i] 

Train loss: 0.4366006  Val loss: 0.76181215  Train acc: 0.8555212224108658  Val acc: 0.7078651685393258


┣██████████████████▋ ┫ [93.50%, 561/600, 07:13:57/07:44:07, 117.98s/i] 

Train loss: 0.4424077  Val loss: 0.7629623  Train acc: 0.8251658957692335  Val acc: 0.7089887640449438


┣███████████████████▎┫ [96.83%, 581/600, 07:28:42/07:43:23, 117.30s/i] 

Train loss: 0.45861435  Val loss: 0.7635531  Train acc: 0.837833660261158  Val acc: 0.7067415730337079


┣████████████████████┫ [100.00%, 600/600, 07:41:32/07:41:32, 38.69s/i] 

Train loss: 0.43274084  Val loss: 0.7694361  Train acc: 0.8430530973451327  Val acc: 0.7044943820224719
Training ended successfully, saving the results
Test loss: 0.41374558  Test acc: 0.8771186440677966


In [24]:
num_iter=400
n=1
results=[]
best_acc=0
for it in progress(1:num_iter)
    train_batches = py"return_data_gs(train_iterator)"
    inputs1, targets1 = train_batches[1];
    inputs1 = permutedims(inputs1.detach().numpy(), (4, 3, 2, 1));
    targets1 = targets1.detach().numpy();
    
    inputs2, targets2 = train_batches[2];
    inputs2 = permutedims(inputs2.detach().numpy(), (4, 3, 2, 1));
    targets2 = targets2.detach().numpy();
    
    inputs3, targets3 = train_batches[3];
    inputs3 = permutedims(inputs3.detach().numpy(), (4, 3, 2, 1));
    targets3 = targets3.detach().numpy();
    
    D1 = @diff model(inputs1, targets1)
    D2 = @diff model(inputs2, targets2)
    D3 = @diff model(inputs3, targets3)

    for w in Knet.params(model)
        g1 = grad(D1, w)
        g2 = grad(D2, w)
        g3 = grad(D3, w)
        MaskGS = abs.(sign.(g1).+sign.(g2).+sign.(g3)).==3
        g = (g1.+g2.+g3).*MaskGS
        update!(w, g)
    end

    if n==1 
        println("First batch is trained successfully")
        best_acc = accuracy(model(inputs_val),targets_val)
    end

    if n%check_freq == 0
        acc = accuracy(model(inputs_val),targets_val)
        if acc > best_acc
            best_acc = acc
            save(string("best_model_gs_debugging_target_", target,".jld2"),"weights",model)
        end
        temp_loss = (model(inputs1, targets1) + model(inputs2, targets2) + model(inputs3, targets3))/3
        temp_acc = (accuracy(model(inputs1),targets1) + accuracy(model(inputs2),targets2) + accuracy(model(inputs3),targets3))/3
        result=(temp_loss, model(inputs_val, targets_val), temp_acc, acc)
        println("Train loss: ",result[1],"  Val loss: ",result[2],
                "  Train acc: ",result[3],"  Val acc: ",result[4])
        push!(results,result)
    end

    (n += 1) > num_iter && break
end 
println("Training ended successfully, saving the results")
save(string("results_gs_debugging_target_", target,".jld2"),"results",results)
best_model = load(string("best_model_gs_debugging_target_", target,".jld2"))["weights"]
tst_loss = best_model(inputs_test, targets_test)
tst_acc = accuracy(model(inputs_test),targets_test)
println("Test loss: ", tst_loss, "  Test acc: ", tst_acc)

┣                    ┫ [0.25%, 1/400, 00:00/00:00, 1296.18i/s] 

First batch is trained successfully


┣█                   ┫ [5.25%, 21/400, 15:14/04:50:11, 114.33s/i] 

Train loss: 0.44486594  Val loss: 0.76705664  Train acc: 0.860558444161479  Val acc: 0.7067415730337079


┣██                  ┫ [10.25%, 41/400, 29:40/04:49:29, 112.30s/i] 

Train loss: 0.42777586  Val loss: 0.7674101  Train acc: 0.8325241533406856  Val acc: 0.7078651685393258


┣███                 ┫ [15.25%, 61/400, 44:03/04:48:50, 113.09s/i] 

Train loss: 0.38727894  Val loss: 0.7696606  Train acc: 0.8749194847020934  Val acc: 0.701123595505618


┣████                ┫ [20.25%, 81/400, 58:26/04:48:36, 112.30s/i] 

Train loss: 0.39649382  Val loss: 0.7714562  Train acc: 0.8742071458139525  Val acc: 0.7089887640449438


┣█████               ┫ [25.25%, 101/400, 01:12:53/04:48:39, 112.60s/i] 

Train loss: 0.45208156  Val loss: 0.77418685  Train acc: 0.8372886473429952  Val acc: 0.7112359550561798


┣██████              ┫ [30.25%, 121/400, 01:27:06/04:47:57, 111.39s/i] 

Train loss: 0.33019066  Val loss: 0.78196895  Train acc: 0.9034597947983775  Val acc: 0.7067415730337079


┣███████             ┫ [35.25%, 141/400, 01:41:30/04:47:56, 114.28s/i] 

Train loss: 0.35138372  Val loss: 0.78369796  Train acc: 0.8879361543221732  Val acc: 0.702247191011236


┣████████            ┫ [40.25%, 161/400, 01:55:48/04:47:43, 112.44s/i] 

Train loss: 0.35266486  Val loss: 0.78697973  Train acc: 0.879671893630376  Val acc: 0.7112359550561798


┣█████████           ┫ [45.25%, 181/400, 02:10:10/04:47:39, 111.07s/i] 

Train loss: 0.35665953  Val loss: 0.79091483  Train acc: 0.8714504310344827  Val acc: 0.7112359550561798


┣██████████          ┫ [50.25%, 201/400, 02:24:55/04:48:24, 112.43s/i] 

Train loss: 0.30091763  Val loss: 0.7885118  Train acc: 0.9061471641502129  Val acc: 0.7067415730337079


┣███████████         ┫ [55.25%, 221/400, 02:39:10/04:48:04, 114.45s/i] 

Train loss: 0.355041  Val loss: 0.7952634  Train acc: 0.8762622892477423  Val acc: 0.7089887640449438


┣████████████        ┫ [60.25%, 241/400, 02:53:34/04:48:04, 111.89s/i] 

Train loss: 0.28731456  Val loss: 0.79488486  Train acc: 0.9270762711864405  Val acc: 0.7112359550561798


┣█████████████       ┫ [65.25%, 261/400, 03:07:53/04:47:56, 114.21s/i] 

Train loss: 0.26922086  Val loss: 0.8026164  Train acc: 0.9123794188314779  Val acc: 0.7067415730337079


┣██████████████      ┫ [70.25%, 281/400, 03:22:11/04:47:48, 110.71s/i] 

Train loss: 0.32188365  Val loss: 0.8111261  Train acc: 0.8870040670203582  Val acc: 0.7056179775280899


┣███████████████     ┫ [75.25%, 301/400, 03:36:36/04:47:51, 111.55s/i] 

Train loss: 0.29973832  Val loss: 0.8159335  Train acc: 0.9073376870032726  Val acc: 0.7


┣████████████████    ┫ [80.25%, 321/400, 03:50:53/04:47:42, 110.68s/i] 

Train loss: 0.31018618  Val loss: 0.819272  Train acc: 0.9091900591936711  Val acc: 0.7033707865168539


┣█████████████████   ┫ [85.25%, 341/400, 04:05:20/04:47:47, 110.82s/i] 

Train loss: 0.27631187  Val loss: 0.8240348  Train acc: 0.9046792328042327  Val acc: 0.7033707865168539


┣██████████████████  ┫ [90.25%, 361/400, 04:19:35/04:47:38, 112.23s/i] 

Train loss: 0.27589992  Val loss: 0.83212143  Train acc: 0.9198818897637796  Val acc: 0.6910112359550562


┣███████████████████ ┫ [95.25%, 381/400, 04:34:02/04:47:42, 114.94s/i] 

Train loss: 0.19290823  Val loss: 0.8324209  Train acc: 0.9657955467640496  Val acc: 0.6943820224719102


┣████████████████████┫ [100.00%, 400/400, 04:46:38/04:46:38, 40.41s/i] 

Train loss: 0.25224972  Val loss: 0.84034234  Train acc: 0.9249880735536474  Val acc: 0.698876404494382
Training ended successfully, saving the results
Test loss: 0.36224905  Test acc: 0.885593220338983
