In [1]:
server_root_path = "/content/drive/MyDrive"

In [None]:
import cv2
from skimage import io
import numpy as np
import os
import glob
from tqdm import tqdm

def crop_and_pad(img1, size=224.0, max_=0):
    # print('## File: img_utills.py | Function: crop_and_pad ##')
    h = img1.shape[0]
    w = img1.shape[1]

    # Maintain the same aspect ratio and resize the image

    if h>w:
        ratio = float(size)/float(h)
        new_h = int(size)
        new_w = int(ratio*w)
    else:
        ratio = float(size)/float(w)
        # print(ratio)
        new_h = int(ratio*h)
        new_w = int(size)

    # print(img1.shape)
    # print('newsize', new_h, new_w)
    img_resized = cv2.resize(img1, (new_w, new_h))

    pad_h = (size-new_h)
    pad_h_start = int(pad_h//2)
    pad_h_stop = int(pad_h - pad_h_start)
    pad_w = (size-new_w)
    pad_w_start = int(pad_w//2)
    pad_w_stop = int(pad_w - pad_w_start)

    img_cropped = cv2.copyMakeBorder(img_resized,pad_h_start,pad_h_stop,pad_w_start,pad_w_stop, cv2.BORDER_CONSTANT, value=int(max_))

    return img_cropped

def save_image(img_path_read, data_dir, category, cat_it, dataset_name, iteration_no, phase, resolution):
    # print('## File: dataset_utils.py | Function: save_image ##')
    image = io.imread(img_path_read)
    crop = crop_and_pad(image, size=resolution)
    img_name = os.path.join(data_dir, phase,  '_'.join(['category', category, 'category_number', str(cat_it), 'dataset', dataset_name, str(iteration_no)])) + '.png'
    image =  image.astype('uint8')
    io.imsave(img_name, crop)

def save_data(server_root_path, dataset_dir, dataset_exp_name, images_folder_name, datasets, C, C_dash, train_val_split, resolution):
    # print('## File: dataset_utils.py | Function: save_data ##')
    data_dir = os.path.join(server_root_path, dataset_dir, dataset_exp_name, images_folder_name)
    print(data_dir)
    # print data_dir
    if os.path.exists(data_dir):
        os.system('rm -rf ' + data_dir + '/*')
    else:
        os.mkdir(data_dir)

    categories = list(C)
    categories.extend(C_dash)

    os.mkdir(os.path.join(data_dir, 'train'))
    os.mkdir(os.path.join(data_dir, 'val'))
    train_iteration_no = 0
    val_iteration_no = 0

    for cat_it, category in tqdm(list(enumerate(categories))):

        for dataset_name in datasets:

            imgs_path = np.array(glob.glob(os.path.join(server_root_path, dataset_dir, dataset_name, category) + '/*'))
            np.random.shuffle(imgs_path)
            split_pos = int(train_val_split*len(imgs_path))

            imgs_path_train = imgs_path[:split_pos]
            imgs_path_val = imgs_path[split_pos:]

            #Save train images
            for img_path_read in imgs_path_train:
                # print(img_path_read)
                save_image(img_path_read, data_dir, category, cat_it, dataset_name, train_iteration_no, 'train', resolution)
                train_iteration_no = train_iteration_no + 1

            #Save val images
            for img_path_read in imgs_path_val:
                val_iteration_no = val_iteration_no + 1
                save_image(img_path_read, data_dir, category, cat_it, dataset_name, val_iteration_no, 'val', resolution)

    temp_paths_train = os.listdir(os.path.join(data_dir, 'train'))
    temp_paths_train = [os.path.join(dataset_dir, dataset_exp_name, images_folder_name, 'train', x) for x in temp_paths_train]
    imgs_path_train = np.array(temp_paths_train)

    temp_paths_val = os.listdir(os.path.join(data_dir, 'val'))
    temp_paths_val = [os.path.join(dataset_dir, dataset_exp_name, images_folder_name, 'val', x) for x in temp_paths_val]
    imgs_path_val = np.array(temp_paths_val)

    if not os.path.exists(os.path.join(server_root_path, dataset_dir, dataset_exp_name, 'index_lists')):
        os.mkdir(os.path.join(server_root_path, dataset_dir, dataset_exp_name, 'index_lists'))

    #Save index list train
    save_path = os.path.join(server_root_path, dataset_dir, dataset_exp_name, 'index_lists')
    np.save(save_path + '/' + images_folder_name + '_index_list_' + 'train.npy',imgs_path_train)

    #Save index list val
    save_path = os.path.join(server_root_path, dataset_dir, dataset_exp_name, 'index_lists')
    np.save(save_path + '/' + images_folder_name + '_index_list_' + 'val.npy', imgs_path_val)

In [None]:
import os
import sys
import numpy as np
from tqdm import tqdm
from natsort import natsorted

#Place where all data is stored
dataset_dir = 'Office-31'

dataset_exp_names = ['usfda_office_31_DtoA']
datasets_sources = [['dslr']]
datasets_targets = ['amazon']

C = ['back_pack', 'calculator', 'keyboard', 'monitor', 'mouse', 'mug', 'bike', 'laptop_computer', 'headphones', 'projector']
Cs_dash = ['bike_helmet', 'bookcase', 'bottle', 'desk_chair', 'desk_lamp', 'desktop_computer', 'file_cabinet', 'letter_tray', 'mobile_phone', 'paper_notebook']
Ct_dash = ['pen', 'phone', 'printer', 'punchers', 'ring_binder', 'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser', 'trash_can']

#number of shared classes between source and target
num_shared_classes = len(C)

#number of unknown classes in target domain
num_unknown_target_classes = len(Cs_dash)

#number of unknown classes in source domain
num_unknown_source_classes = len(Ct_dash)

for dataset_exp_name, datasets_source, datasets_target in tqdm(list(zip(dataset_exp_names, datasets_sources, datasets_targets))):

    print('running', dataset_exp_name)

    if dataset_exp_name.split('_')[-1] == 'saito':
      resolution = 32
    else:
      resolution = 224

    source_train_val_split = 0.9
    target_train_val_split = 1

    # Create a folder inside the dataset experiment folder server_root_path, dataset_dir, dataset_exp_name
    if not os.path.exists(os.path.join(server_root_path, dataset_dir, dataset_exp_name)):
        os.mkdir(os.path.join(server_root_path, dataset_dir, dataset_exp_name))
    else:
        os.system('rm -rf ' + os.path.join(server_root_path, dataset_dir, dataset_exp_name) + '/*')

    num_datasets = len(datasets_source) + 1
    all_datasets = datasets_source + [datasets_target]

    print('shared_classes: {}'.format(C))
    print('source_private_classes: {}'.format(Cs_dash))
    print('target_private_classes: {}'.format(Ct_dash))

    #Create Source Data
    save_data(server_root_path, dataset_dir, dataset_exp_name, 'source_images', datasets_source, C, Cs_dash, source_train_val_split, resolution=resolution)

    #Create Target Data
    save_data(server_root_path, dataset_dir, dataset_exp_name, 'target_images', [datasets_target], C, Ct_dash, target_train_val_split, resolution=resolution)

  0%|          | 0/1 [00:00<?, ?it/s]

running usfda_office_31_DtoA
shared_classes: ['back_pack', 'calculator', 'keyboard', 'monitor', 'mouse', 'mug', 'bike', 'laptop_computer', 'headphones', 'projector']
source_private_classes: ['bike_helmet', 'bookcase', 'bottle', 'desk_chair', 'desk_lamp', 'desktop_computer', 'file_cabinet', 'letter_tray', 'mobile_phone', 'paper_notebook']
target_private_classes: ['pen', 'phone', 'printer', 'punchers', 'ring_binder', 'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser', 'trash_can']
/content/drive/MyDrive/Office-31/usfda_office_31_DtoA/source_images



  0%|          | 0/20 [00:00<?, ?it/s][A
  5%|▌         | 1/20 [00:00<00:12,  1.50it/s][A
 10%|█         | 2/20 [00:01<00:11,  1.56it/s][A
 15%|█▌        | 3/20 [00:01<00:09,  1.73it/s][A
 20%|██        | 4/20 [00:02<00:12,  1.31it/s][A
 25%|██▌       | 5/20 [00:03<00:10,  1.47it/s][A
 30%|███       | 6/20 [00:03<00:08,  1.73it/s][A
 35%|███▌      | 7/20 [00:04<00:08,  1.45it/s][A
 40%|████      | 8/20 [00:05<00:09,  1.24it/s][A
 45%|████▌     | 9/20 [00:06<00:08,  1.35it/s][A
 50%|█████     | 10/20 [00:07<00:08,  1.20it/s][A
 55%|█████▌    | 11/20 [00:08<00:08,  1.04it/s][A
 60%|██████    | 12/20 [00:09<00:06,  1.21it/s][A
 65%|██████▌   | 13/20 [00:09<00:05,  1.21it/s][A
 70%|███████   | 14/20 [00:10<00:04,  1.30it/s][A
 75%|███████▌  | 15/20 [00:11<00:03,  1.38it/s][A
 80%|████████  | 16/20 [00:11<00:02,  1.39it/s][A
 85%|████████▌ | 17/20 [00:12<00:02,  1.34it/s][A
 90%|█████████ | 18/20 [00:13<00:01,  1.28it/s][A
 95%|█████████▌| 19/20 [00:15<00:01,  1.05s/it]

/content/drive/MyDrive/Office-31/usfda_office_31_DtoA/target_images



  0%|          | 0/21 [00:00<?, ?it/s][A
  5%|▍         | 1/21 [00:03<01:07,  3.35s/it][A
 10%|▉         | 2/21 [00:06<00:59,  3.15s/it][A
 14%|█▍        | 3/21 [00:09<00:56,  3.13s/it][A
 19%|█▉        | 4/21 [00:13<00:56,  3.32s/it][A
 24%|██▍       | 5/21 [00:16<00:54,  3.43s/it][A
 29%|██▊       | 6/21 [00:19<00:50,  3.36s/it][A
 33%|███▎      | 7/21 [00:22<00:43,  3.12s/it][A
 38%|███▊      | 8/21 [00:26<00:41,  3.23s/it][A
 43%|████▎     | 9/21 [00:29<00:39,  3.30s/it][A
 48%|████▊     | 10/21 [00:32<00:34,  3.16s/it][A
 52%|█████▏    | 11/21 [00:35<00:30,  3.04s/it][A
 57%|█████▋    | 12/21 [00:38<00:27,  3.07s/it][A
 62%|██████▏   | 13/21 [00:41<00:25,  3.20s/it][A
 67%|██████▋   | 14/21 [00:45<00:22,  3.26s/it][A
 71%|███████▏  | 15/21 [00:48<00:19,  3.21s/it][A
  return func(*args, **kwargs)

 81%|████████  | 17/21 [00:53<00:11,  2.99s/it][A
 86%|████████▌ | 18/21 [00:57<00:09,  3.16s/it][A
 90%|█████████ | 19/21 [01:00<00:06,  3.20s/it][A
 95%|█████████▌|

In [2]:
import matplotlib
matplotlib.use('Agg')
import os
import sys
import numpy as np
from skimage import io
from tqdm import tqdm
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
from glob import glob
import torch
from natsort import natsorted
import torch
import re

# Set device to GPU if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# 2-way or 3-way splicing
n_way_splice = 2

#Place where all data is stored
# dataset_dir = 'data/digits'
dataset_dir = 'Office-31'

dataset_exp_names = ['usfda_office_31_DtoA']
datasets_sources = [['dslr']]
datasets_targets = ['amazon']

C = ['back_pack', 'calculator', 'keyboard', 'monitor', 'mouse', 'mug', 'bike', 'laptop_computer', 'headphones', 'projector']
Cs_dash = ['bike_helmet', 'bookcase', 'bottle', 'desk_chair', 'desk_lamp', 'desktop_computer', 'file_cabinet', 'letter_tray', 'mobile_phone', 'paper_notebook']
Ct_dash = ['pen', 'phone', 'printer', 'punchers', 'ring_binder', 'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser', 'trash_can']

#number of shared classes between source and target
num_shared_classes = len(C)

#number of unknown classes in target domain
num_unknown_target_classes = len(Cs_dash)

#number of unknown classes in source domain
num_unknown_source_classes = len(Ct_dash)

# Negative categories
num_source_classes = len(C) + len(Cs_dash)
temp_negative_category_dict = {}
c = 0
for i in range(num_source_classes):
	for j in range(i+1, num_source_classes):
		temp_negative_category_dict[(i, j)] = c
		c += 1

negative_category_dict = {}

for key in temp_negative_category_dict:
	negative_category_dict[key] = temp_negative_category_dict[key]
	negative_category_dict[(key[1], key[0])] = temp_negative_category_dict[key]


def generate_spline(num_points, points = None, vertical = False, integer = False):

	'''
		Performs spline interpolation between points, and returns num_points coordinates.
		Args:
			points: list of (y,x) coordinates. If none, then samples from num_points
			vertical: True if points are y coordinates (False if points are x coordinates)
			integer: True if the function should return integer points
	'''

	HMAX, WMAX = num_points, num_points
	window = 30
	center = (HMAX/2, WMAX/2)

	rand_center_point = (np.random.randint(low=int(center[0]-window), high=int(center[0]+window)), np.random.randint(low=int(center[1]-(window/3)), high=int(center[1]+(window/3))))

	points = [(0, np.random.randint(HMAX)), rand_center_point, (WMAX-1, np.random.randint(HMAX))]

	px = [p[0] for p in points]
	px.extend([0, 223])

	py = [p[1] for p in points]
	py.extend([0, 223])

	# spl = spline_interpolation([p[0] for p in points], [p[1] for p in points], range(num_points), order=7, kind='smoothest')
	spl = interp1d([p[0] for p in points], [p[1] for p in points], kind='quadratic')(range(num_points))

	x = list(range(num_points))
	y = [p for p in spl]

	if integer:
		x = [int(a) for a in x]
		y = [int(a) for a in y]

	if not(vertical):
		return (x, y), [px, py]
	else:
		return (y, x), [py, px]


def get_negative_image(image1, image2, spline = None, vertical_division = False):

	'''
		Returns 2 negative images by merging 2 images horizontally or vertically (vertical_division).
	'''

	if not(vertical_division):

		mask1 = torch.ones(image1.shape).to(device) # Image to the left of spline
		for (x, y) in spline:
			mask1[y:, x] = 0
		mask2 = 1 - mask1 # Image to the right spline

		# print(mask1)
		# print(mask2)

		# io.imsave('temp_mask1.jpg', (mask1.cpu().numpy() * 255).astype(np.uint8))
		# io.imsave('temp_mask2.jpg', (mask2.cpu().numpy() * 255).astype(np.uint8))

		mask1 = mask1.float()
		mask2 = mask2.float()

		newim1 = mask1 * image1 + mask2 * image2
		newim2 = mask1 * image2 + mask2 * image1

		return newim1, newim2, mask1, mask2

	else:

		mask1 = torch.ones(image1.shape).to(device) # Image above spline
		for (x, y) in spline:
			mask1[y, x:] = 0
		mask2 = 1 - mask1 # Image below spline

		# print(mask1)
		# print(mask2)

		# io.imsave('temp_mask1.jpg', (mask1.cpu().numpy() * 255).astype(np.uint8))
		# io.imsave('temp_mask2.jpg', (mask2.cpu().numpy() * 255).astype(np.uint8))

		mask1 = mask1.float()
		mask2 = mask2.float()

		newim1 = mask1 * image1 + mask2 * image2
		newim2 = mask1 * image2 + mask2 * image1

		return newim1, newim2, mask1, mask2


def merge_images(image1, image2):

	assert image1.shape == image2.shape

	image1, image2 = torch.from_numpy(image1).to(device), torch.from_numpy(image2).to(device)
	image1, image2 = image1.float(), image2.float()

	vert = False
	(x, y), _ = generate_spline(image1.shape[0], vertical=vert, integer=True)
	spline = list(zip(x, y))
	hor_I1, hor_I2, hor_mask1, hor_mask2 = get_negative_image(image1, image2, spline, vertical_division=vert)

	vert = True
	(x, y), _ = generate_spline(image1.shape[0], vertical=vert, integer=True)
	spline = list(zip(x, y))
	ver_I1, ver_I2, ver_mask1, ver_mask2 = get_negative_image(image1, image2, spline, vertical_division=vert)

	return hor_I1, hor_I2, ver_I1, ver_I2, hor_mask1, hor_mask2, ver_mask1, ver_mask2


def get_negative_image_3(image1, image2, image3, spline_vert, spline_hor):

	'''
		Merges 3 images, in a 3-way splicing.

		spline_vert -> vertical spline (going from top to bottom edge)
		spline_hor -> horizontal spline (going from left to right edge)

		image1 -> class A
		image2 -> class A
		image3 -> class B

		The following 4 scenarios can arise

			A | A   	A |       	  | A     	  B
		1)  -----   2)	--| B   3)	B |--   4)	-----
			  B     	A |       	  | A   	A | A

		The corresponding operations are:

		1) spline_vert_left * image1 + spline_vert_right * image2 + spline_hor_bottom * image3
		2) spline_hor_top * image1 + spline_hor_bottom * image2 + spline_vert_right * image3
		3) spline_hor_top * image1 + spline_hor_bottom * image2 + spline_vert_left * image3
		4) spline_vert_left * image1 + spline_vert_right * image2 + spline_hor_top * image3

		Out of these, some will not have enough representation from one class
	'''

	spline_vert_left = torch.ones(image1.shape).to(device) # Image to the left of spline
	for (x, y) in spline_vert:
		spline_vert_left[y, x:] = 0
	spline_vert_right = 1 - spline_vert_left # Image to the right spline

	spline_hor_top = torch.ones(image1.shape).to(device) # Image to the top of spline
	for (x, y) in spline_hor:
		spline_hor_top[y:, x] = 0
	spline_hor_bottom = 1 - spline_hor_top # Image to the bottom of spline

	I1 = spline_hor_top * (spline_vert_left * image1 + spline_vert_right * image2) + spline_hor_bottom * image3
	I2 = spline_vert_left * (spline_hor_top * image1 + spline_hor_bottom * image2) + spline_vert_right * image3
	I3 = spline_vert_right * (spline_hor_top * image1 + spline_hor_bottom * image2) + spline_vert_left * image3
	I4 = spline_hor_bottom * (spline_vert_left * image1 + spline_vert_right * image2) + spline_hor_top * image3

	M1 = spline_hor_top * (spline_vert_left * torch.ones(image1.shape).to(device) * 32 + spline_vert_right * torch.ones(image1.shape).to(device) * 64) + spline_hor_bottom * torch.ones(image1.shape).to(device) * 192
	M2 = spline_vert_left * (spline_hor_top * torch.ones(image1.shape).to(device) * 32 + spline_hor_bottom * torch.ones(image1.shape).to(device) * 64) + spline_vert_right * torch.ones(image1.shape).to(device) * 192
	M3 = spline_vert_right * (spline_hor_top * torch.ones(image1.shape).to(device) * 32 + spline_hor_bottom * torch.ones(image1.shape).to(device) * 64) + spline_vert_left * torch.ones(image1.shape).to(device) * 192
	M4 = spline_hor_bottom * (spline_vert_left * torch.ones(image1.shape).to(device) * 32 + spline_vert_right * torch.ones(image1.shape).to(device) * 64) + spline_hor_top * torch.ones(image1.shape).to(device) * 192

	return I1, I2, I3, I4, M1, M2, M3, M4


def merge_images_3(image1, image2, image3):

	(x, y), _ = generate_spline(image1.shape[0], vertical=True, integer=True)
	spline_vert = list(zip(x, y))

	(x, y), _ = generate_spline(image1.shape[0], vertical=False, integer=True)
	spline_hor = list(zip(x, y))

	image1, image2, image3 = torch.from_numpy(image1).to(device), torch.from_numpy(image2).to(device), torch.from_numpy(image3).to(device)
	image1, image2, image3 = image1.float(), image2.float(), image3.float()

	return get_negative_image_3(image1, image2, image3, spline_vert, spline_hor)


def get_category_number(c1, c2):

	'''
		Returns the category number given two classes.
	'''

	return negative_category_dict[(c1, c2)]


for dataset_exp_name, datasets_source, datasets_target in zip(dataset_exp_names, datasets_sources, datasets_targets):
    print('running', dataset_exp_name)

    print('shared_classes: {}'.format(C))
    print('source_private_classes: {}'.format(Cs_dash))
    print('target_private_classes: {}'.format(Ct_dash))

    filenames = glob(os.path.join(server_root_path, dataset_dir, dataset_exp_name, 'source_images', 'train', '*.png'))
    savepath = os.path.join(server_root_path, dataset_dir, dataset_exp_name, 'negative_images')
    savepath_mask = os.path.join(server_root_path, dataset_dir, dataset_exp_name, 'negative_masks')

    if os.path.exists(savepath):
        os.system('rm -rf ' + savepath)
    os.mkdir(savepath)

    if os.path.exists(savepath_mask):
        os.system('rm -rf ' + savepath_mask)
    os.mkdir(savepath_mask)

    # Remove augmented files
    # filenames = [a for a in filenames if a.split('/')[-1].split('_')[0]=='category']
    L = len(filenames)

    counter = 0

    class_wise_files = {}

    if n_way_splice == 3:
        for cnum in range(num_source_classes):
            class_wise_files[str(cnum)] = []
            pattern = re.compile('_category_number_' + str(cnum) + '_dataset_' + str(datasets_source[0]) + '_')
            for fname in filenames:
                if pattern.search(fname) is not None:
                    class_wise_files[str(cnum)].append(fname)

    for i in tqdm(range(5000), desc="Processing negative images", ncols=100):  # Wrap the inner loop with tqdm

        if n_way_splice == 2:
            im1 = np.random.randint(L)
            im2 = np.random.randint(L)

            f1, f2 = filenames[im1], filenames[im2]
            c1, c2 = int(f1.split('_')[-4]), int(f2.split('_')[-4])

            while c1 == c2:
                im2 = np.random.randint(L)
                f2 = filenames[im2]
                c2 = int(f2.split('_')[-4])

            image1, image2 = io.imread(f1), io.imread(f2)

            a, b, c, d, e, f, g, h = merge_images(image1, image2)
            merged_images = [a, b, c, d]
            merged_masks = [e, f, g, h]

            cnum = get_category_number(c1, c2)

            for image, mask in zip(merged_images, merged_masks):
                counter += 1
                save_filename = 'category_' + str(cnum) + '_category_number_' + str(cnum) + '_dataset_' + str(datasets_source[0]) + '_' + str(counter) + '.png'
                save_filename_mask = 'mask_category_' + str(cnum) + '_category_number_' + str(cnum) + '_dataset_' + str(datasets_source[0]) + '_' + str(counter) + '.png'
                io.imsave(os.path.join(savepath, save_filename), image.cpu().numpy().astype(np.uint8))
                io.imsave(os.path.join(savepath_mask, save_filename_mask), (mask * 255).cpu().numpy().astype(np.uint8))

        elif n_way_splice == 3:
            # first image
            im1 = np.random.randint(L)
            f1 = filenames[im1]
            c1 = int(f1.split('_')[-4])

            # second image
            im2 = np.random.randint(len(class_wise_files[str(c1)]))
            f2 = class_wise_files[str(c1)][im2]
            c2 = c1

            while f1 == f2:
                im2 = np.random.randint(len(class_wise_files[str(c1)]))
                f2 = class_wise_files[str(c1)][im2]

            # third image
            x = np.random.randint(len(list(class_wise_files.keys())))
            c3 = int((list(class_wise_files.keys())[x]))

            while c3 == c2:
                x = np.random.randint(len(list(class_wise_files.keys())))
                c3 = int((list(class_wise_files.keys())[x]))

            im3 = np.random.randint(len(class_wise_files[str(c3)]))
            f3 = class_wise_files[str(c3)][im3]

            image1, image2, image3 = io.imread(f1), io.imread(f2), io.imread(f3)
            i1, i2, i3, i4, m1, m2, m3, m4 = merge_images_3(image1, image2, image3)
            merged_images = [i1, i2, i3, i4]
            merged_masks = [m1, m2, m3, m4]

            cnum = get_category_number(c1, c3)

            for image, mask in zip(merged_images, merged_masks):
                counter += 1
                save_filename = 'category_' + str(cnum) + '_category_number_' + str(cnum) + '_dataset_' + str(datasets_source[0]) + '_' + str(counter) + '.png'
                save_filename_mask = 'mask_category_' + str(cnum) + '_category_number_' + str(cnum) + '_dataset_' + str(datasets_source[0]) + '_' + str(counter) + '.png'
                io.imsave(os.path.join(savepath, save_filename), image.cpu().numpy().astype(np.uint8))
                io.imsave(os.path.join(savepath_mask, save_filename_mask), (mask).cpu().numpy().astype(np.uint8))

running usfda_office_31_DtoA
shared_classes: ['back_pack', 'calculator', 'keyboard', 'monitor', 'mouse', 'mug', 'bike', 'laptop_computer', 'headphones', 'projector']
source_private_classes: ['bike_helmet', 'bookcase', 'bottle', 'desk_chair', 'desk_lamp', 'desktop_computer', 'file_cabinet', 'letter_tray', 'mobile_phone', 'paper_notebook']
target_private_classes: ['pen', 'phone', 'printer', 'punchers', 'ring_binder', 'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser', 'trash_can']


Processing negative images: 100%|███████████████████████████████| 5000/5000 [14:46<00:00,  5.64it/s]


In [11]:
dataset_dir = 'Office-31'
dataset_exp_name = 'usfda_office_31_DtoA'
images_folder_name = 'negative_images'

data_dir = os.path.join(server_root_path, dataset_dir, dataset_exp_name, images_folder_name)
temp_paths_train = os.listdir(data_dir)

temp_paths_train = [os.path.join(dataset_dir, dataset_exp_name, images_folder_name, x) for x in temp_paths_train]
imgs_path_train = np.array(temp_paths_train)

print(data_dir)
print(len(temp_paths_train))
print(len(imgs_path_train))

/content/drive/MyDrive/Office-31/usfda_office_31_DtoA/negative_images
8733
8733


In [12]:
#Save index list train
save_path = os.path.join(server_root_path, dataset_dir, dataset_exp_name, 'index_lists')
np.save(save_path + '/' + images_folder_name + '_index_list_' + 'train.npy', imgs_path_train)

In [4]:
import cv2
import glob
from skimage import io
import numpy as np
import os
from tqdm import tqdm

import torch
import torchvision

chop_distances = {}

def get_chop_distance(rotated_mat, angle):

    '''
        Returns the distance at which the rotated image should be chopped off. Used in rotateImage() function.
    '''

    global chop_distances

    if angle in chop_distances.keys():
        return chop_distances[angle]

    if angle > 0:

        x = 0
        y = 0

        while(rotated_mat[y, x, 0] == 0):
            y += 1

        chop_distances[angle] = y

        # print(chop_distances)
        return y

    else:

        x = rotated_mat.shape[1] - 1
        y = 0

        while(rotated_mat[y, x, 0] == 0):
            y += 1

        chop_distances[angle] = y

        # print(chop_distances)
        return y


def rotateImage(mat, angle):

    '''
        Rotates an image (angle in degrees) and zooms into the image to avoid crop borders.
    '''

    height, width = mat.shape[:2] # image shape has 3 dimensions
    image_center = (width/2, height/2) # getRotationMatrix2D needs coordinates in reverse order (width, height) compared to shape

    rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.)

    # rotation calculates the cos and sin, taking absolutes of those.
    abs_cos = abs(rotation_mat[0,0])
    abs_sin = abs(rotation_mat[0,1])

    # find the new width and height bounds
    bound_w = int(height * abs_sin + width * abs_cos)
    bound_h = int(height * abs_cos + width * abs_sin)

    # subtract old image center (bringing image back to origo) and adding the new image center coordinates
    rotation_mat[0, 2] += bound_w/2 - image_center[0]
    rotation_mat[1, 2] += bound_h/2 - image_center[1]

    # rotate image with the new bounds and translated rotation matrix
    rotated_mat = cv2.warpAffine(mat, rotation_mat, (bound_w, bound_h))

    H, W, c = rotated_mat.shape

    d = get_chop_distance(rotated_mat, angle)
    chopped_image = rotated_mat[d : H-d, d : W-d]

    resized_image = cv2.resize(chopped_image, (224, 224))

    return resized_image


def random_crop(image, cropshape, padsize):

    '''
        Takes a random crop from the image.
    '''

    # crop height, width, channels
    H, W, C = image.shape
    p = padsize

    # Check shapes etc.
    if type(cropshape) == int:
        cH = cW = cropshape
        assert cH <= H, 'Crop size is greater than image size'
    elif len(cropshape) == 2:
        cH, cW = cropshape
        assert cH <= H and cW <= W, 'Crop size is greater than image size'
    else:
        raise Exception('Wrong crop shape (use either int (s) or tuple (h, w))')

    if type(padsize) == int:
        pH = pW = padsize
    elif len(padsize) == 2:
        pH, pW = padsize
    else:
        raise Exception('Wrong pad shape (use either int (s) or tuple (h, w))')

    # Created padded image
    paddedimage = np.zeros((cH + 2*pH, cW + 2*pW, C), dtype = image.dtype)
    paddedimage[pH:pH+H, pW:pW+W, :] = image

    # Output image
    outimage = np.zeros((cH, cW, C), dtype = image.dtype)

    # Randomly chose a start location (this is the random step)
    startx = np.random.randint(2*padsize)
    starty = np.random.randint(2*padsize)

    # Crop (H, W, C) shaped image from start locations
    outimage = paddedimage[starty:starty+cH, startx:startx+cH, :]

    return outimage


def random_horizontal_flip(image, always_flip=True):

    '''
        Randomly flips the given image.
    '''

    if always_flip:
        r = 1 # always flip
    else:
        r = np.random.rand()

    if r >= 0:
        return np.fliplr(image), True
    else:
        return np.array(image), False


def rgb_flip(img, reorder):

    '''
        Flips the channels in the given order.
    '''

    newim = img.copy()
    newim[:, :, 0] = img[:, :, reorder[0]]
    newim[:, :, 1] = img[:, :, reorder[1]]
    newim[:, :, 2] = img[:, :, reorder[2]]

    return newim


def color_jitter(img, brightness=0.25, contrast=0.25, saturation=0.25, hue=0.25):

    '''
        Applies color jitter augmentation (brightness, contrast, saturation, hue).
    '''

    tensorImage = torchvision.transforms.ToTensor()(img)
    pilImage = torchvision.transforms.ToPILImage(mode='RGB')(tensorImage)
    jitteredImage = torchvision.transforms.ColorJitter(brightness, contrast, saturation, hue)(pilImage)
    jitteredArray = np.array(jitteredImage)

    return jitteredArray


def count_class_distribution(images_path):

    '''
        Get the number of images in each class. Call before augmenting dataset so that each class has same number of examples.
    '''

    classdict = {}

    for im in images_path:

        a = im.split('/')[-1]
        c = int(a.split('category_number_')[1].split('_')[0])

        if c in classdict:
            classdict[c].append(im)
        else:
            classdict[c] = [im]

    return classdict


def augment_images(images_path):

    '''
        Main function to augment all images with filenames given in images_path.
    '''

    # print('\nAugment Images')

    # Rotate and augment first
    for xx in images_path:

        image = io.imread(xx)

        augmentated_images = []

        # Flip image
        image_flip, success = random_horizontal_flip(image)
        if success:
            file_name = '/'.join(xx.split('/')[:-1]) + '/' + 'flip_' +  xx.split('/')[-1]
            augmentated_images.append(file_name)
            io.imsave(file_name, image_flip)

        # Rotate Image
        for yy in [-3, -2, -1, 1, 2, 3]:
            image_rot = rotateImage(image, yy*5)
            file_name = '/'.join(xx.split('/')[:-1]) + '/' + 'rotate' + str(yy) + '_' +  xx.split('/')[-1]
            io.imsave(file_name, image_rot)

        # RGB flip
        for order in [(0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]:
            r = str(order[0])
            g = str(order[1])
            b = str(order[2])
            image_rgb_flipped = rgb_flip(image, order)
            file_name = '/'.join(xx.split('/')[:-1]) + '/' + 'rgbflip' + r + g + b + '_' +  xx.split('/')[-1]
            io.imsave(file_name, image_rgb_flipped)

        # Color Jitter
        for jj in range(5):
            image_jittered = color_jitter(image, brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05)
            # print(image_jittered.shape)
            file_name = '/'.join(xx.split('/')[:-1]) + '/' + 'jitter' + str(jj) + '_' +  xx.split('/')[-1]
            io.imsave(file_name, image_jittered)


def balance_classes(images_path):

    '''
        Balances classes with randomly cropped images.
    '''

    # print('\nBalance Classes')

    # Balance classes with random crop (only augment classes with lesser examples to create even distribution across all classes)
    # (Both in the source and target domain)

    classdict = count_class_distribution(images_path)
    maxCount = np.max([len(classdict[c]) for c in classdict])
    print('Initial Class distribution', [len(classdict[c]) for c in classdict])

    writtenImagesCount = {c:0 for c in classdict}

    for c in tqdm(classdict, desc="Balancing", ncols=100):
        n_c = len(classdict[c])

        for i in range(maxCount - n_c):
            randim = classdict[c][np.random.randint(len(classdict[c]))]
            cropnum = 1
            file_name = '/'.join(randim.split('/')[:-1]) + '/' + 'randcrop_' + str(cropnum) + '_' +  randim.split('/')[-1]
            while(True):
                if os.path.exists(file_name):
                    cropnum += 1
                    file_name = '/'.join(randim.split('/')[:-1]) + '/' + 'randcrop_' + str(cropnum) + '_' +  randim.split('/')[-1]
                else:
                    image = io.imread(randim)
                    image_randcrop = random_crop(image, 224, 20)
                    io.imsave(file_name, image_randcrop)
                    writtenImagesCount[c] += 1
                    break

    print('Added Images', [writtenImagesCount[c] for c in writtenImagesCount])


def make_augmentation_dictionary(index_list_path):

    '''
        Creates a dictionary with image filenames, and the list of their augmented images filenames.
    '''

    def get_augmentations(image_filename):

        '''
            Gets the names of the augmentation files.
        '''

        xx = image_filename

        aug_files = []

        # Flip image
        file_name = '/'.join(xx.split('/')[:-1]) + '/' + 'flip_' +  xx.split('/')[-1]
        aug_files.append(file_name)

        # Rotate Image
        for yy in [-3, -2, -1, 1, 2, 3]:
            file_name = '/'.join(xx.split('/')[:-1]) + '/' + 'rotate' + str(yy) + '_' +  xx.split('/')[-1]
            aug_files.append(file_name)

        # RGB flip
        for order in [(0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]:
            r = str(order[0])
            g = str(order[1])
            b = str(order[2])
            file_name = '/'.join(xx.split('/')[:-1]) + '/' + 'rgbflip' + r + g + b + '_' +  xx.split('/')[-1]
            aug_files.append(file_name)

        # Color Jitter
        for jj in range(5):
            file_name = '/'.join(xx.split('/')[:-1]) + '/' + 'jitter' + str(jj) + '_' +  xx.split('/')[-1]
            aug_files.append(file_name)

        return aug_files

    def get_augmentations_prefix():

        '''
            Returns the list of prefixes for augmentation images.
        '''

        aug_prefix = []

        # Flip image
        aug_prefix.append('flip')

        # Rotate Image
        for yy in [-3, -2, -1, 1, 2, 3]:
            aug_prefix.append('rotate' + str(yy))

        # RGB flip
        for order in [(0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]:
            r = str(order[0]); g = str(order[1]); b = str(order[2])
            aug_prefix.append('rgbflip' + r + g + b)

        # Color Jitter
        for jj in range(5):
            aug_prefix.append('jitter' + str(jj))

        # Randcrop
        aug_prefix.append('randcrop')

        return aug_prefix

    all_fils = np.load(index_list_path)
    pfix = get_augmentations_prefix()

    only_images = [a for a in all_fils if a.split('/')[-1].split('_')[0] not in pfix]
    print(len(only_images))

    print('sanity check 1')
    for f in tqdm(only_images):
        assert f.split('/')[-1].split('_')[0]=='category'
    print('correct')

    dictionary = {}

    print('sanity check 2')
    for f in tqdm(only_images):
        augfiles = get_augmentations(f)
        dictionary[f] = augfiles
        for fil in augfiles:
            assert fil in all_fils
    print('correct')

    return dictionary

In [5]:
from tqdm import tqdm
import os
import glob
import numpy as np

dataset_name = 'Office-31'  # 'office_31_dataset'
experiments = ['usfda_office_31_DtoA']

for dataset_exp_name in experiments:

    # Rotate and augment source images
    print('Rotation augmentation on source images')
    os.system('rm -rf /content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/source_images/train/flip*.png')
    os.system('rm -rf /content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/source_images/train/randcrop*.png')
    os.system('rm -rf /content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/source_images/train/rotate*.png')
    os.system('rm -rf /content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/source_images/train/rgbflip*.png')
    os.system('rm -rf /content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/source_images/train/jitter*.png')

    files = glob.glob('/content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/source_images/train/*.png')

    # Wrap the entire file list with tqdm to track progress for all files
    for file in tqdm(files, desc="Augmenting source images", ncols=100):
        augment_images([file])  # Process each file

    # Balance source classes
    print('Class balancing on source images')
    balance_classes(files)

    # Rotate and augment target images
    print('Rotation augmentation on target images')
    os.system('rm -rf /content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/target_images/train/flip*.png')
    os.system('rm -rf /content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/target_images/train/randcrop*.png')
    os.system('rm -rf /content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/target_images/train/rotate*.png')
    os.system('rm -rf /content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/target_images/train/rgbflip*.png')
    os.system('rm -rf /content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/target_images/train/jitter*.png')

    files = glob.glob('/content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/target_images/train/*.png')
    for file in tqdm(files, desc="Augmenting target images", ncols=100):
        augment_images([file])

    # Balance target classes
    print('Class balancing on target images')
    balance_classes(files)

    # # Save index list source images
    # temp_paths_train = os.listdir('/content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/source_images/train/')
    # temp_paths_train = ['/content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/source_images/train/' + x for x in temp_paths_train]
    # save_path = '/content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/index_lists/source_images_index_list_train.npy'
    # np.save(save_path , np.array(temp_paths_train))

    # # Save index list target images
    # temp_paths_train = os.listdir('/content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/target_images/train/')
    # temp_paths_train = ['/content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/target_images/train/' + x for x in temp_paths_train]
    # save_path = '/content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/index_lists/target_images_index_list_train.npy'
    # np.save(save_path , np.array(temp_paths_train))

    # # Make augmentation dictionary for source
    # source_list_filename = '/content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/index_lists/source_images_index_list_train.npy'
    # source_augs = make_augmentation_dictionary(source_list_filename)
    # source_aug_dict_path = '/content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/index_lists/source_images_aug_dict_train.npy'
    # np.save(source_aug_dict_path, source_augs)

    # # Make augmentation dictionary for target
    # target_list_filename = '/content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/index_lists/target_images_index_list_train.npy'
    # target_augs = make_augmentation_dictionary(target_list_filename)
    # target_aug_dict_path = '/content/drive/MyDrive/' + dataset_name + '/' + dataset_exp_name + '/index_lists/target_images_aug_dict_train.npy'
    # np.save(target_aug_dict_path, target_augs)

    # Define the save path for both source and target images
    save_path = os.path.join('/content/drive/MyDrive', dataset_name, dataset_exp_name, 'index_lists')

    # Save index list for source images
    source_images_path = os.listdir(os.path.join('/content/drive/MyDrive', dataset_name, dataset_exp_name, 'source_images', 'train'))
    source_images_path = [os.path.join('/content/drive/MyDrive', dataset_name, dataset_exp_name, 'source_images', 'train', x) for x in source_images_path]
    np.save(os.path.join(save_path, 'source_images_index_list_train.npy'), source_images_path)

    # Save index list for target images
    target_images_path = os.listdir(os.path.join('/content/drive/MyDrive', dataset_name, dataset_exp_name, 'target_images', 'train'))
    target_images_path = [os.path.join('/content/drive/MyDrive', dataset_name, dataset_exp_name, 'target_images', 'train', x) for x in target_images_path]
    np.save(os.path.join(save_path, 'target_images_index_list_train.npy'), target_images_path)

    # Make augmentation dictionary for source
    source_list_filename = os.path.join('/content/drive/MyDrive', dataset_name, dataset_exp_name, 'index_lists', 'source_images_index_list_train.npy')
    source_augs = make_augmentation_dictionary(source_list_filename)
    source_aug_dict_path = os.path.join('/content/drive/MyDrive', dataset_name, dataset_exp_name, 'index_lists', 'source_images_aug_dict_train.npy')
    np.save(source_aug_dict_path, source_augs)

    # Make augmentation dictionary for target
    target_list_filename = os.path.join('/content/drive/MyDrive', dataset_name, dataset_exp_name, 'index_lists', 'target_images_index_list_train.npy')
    target_augs = make_augmentation_dictionary(target_list_filename)
    target_aug_dict_path = os.path.join('/content/drive/MyDrive', dataset_name, dataset_exp_name, 'index_lists', 'target_images_aug_dict_train.npy')
    np.save(target_aug_dict_path, target_augs)


Rotation augmentation on source images


Augmenting source images: 100%|███████████████████████████████████| 279/279 [03:04<00:00,  1.51it/s]


Class balancing on source images
Initial Class distribution [10, 10, 9, 19, 10, 7, 18, 21, 11, 20, 21, 10, 14, 11, 12, 13, 13, 14, 27, 9]


Balancing: 100%|████████████████████████████████████████████████████| 20/20 [00:09<00:00,  2.03it/s]


Added Images [17, 17, 18, 8, 17, 20, 9, 6, 16, 7, 6, 17, 13, 16, 15, 14, 14, 13, 0, 18]
Rotation augmentation on target images


  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return func(*args, **kwargs)
  return

Class balancing on target images
Initial Class distribution [90, 75, 100, 99, 99, 96, 64, 100, 94, 82, 100, 99, 98, 95, 93, 100, 98, 92, 94, 100, 99]


Balancing: 100%|████████████████████████████████████████████████████| 21/21 [00:03<00:00,  5.86it/s]


Added Images [10, 25, 0, 1, 1, 4, 36, 0, 6, 18, 0, 1, 2, 5, 7, 0, 2, 8, 6, 0, 1]
279
sanity check 1


100%|██████████| 279/279 [00:00<00:00, 104334.06it/s]


correct
sanity check 2


100%|██████████| 279/279 [00:03<00:00, 91.70it/s] 


correct
1967
sanity check 1


100%|██████████| 1967/1967 [00:00<00:00, 350133.51it/s]


correct
sanity check 2


100%|██████████| 1967/1967 [01:31<00:00, 21.40it/s]


correct


In [3]:
import os
import torch
from glob import glob
from natsort import natsorted

# Configuration Settings
settings = {}

# Supervised Training Toggle
settings['running_supervised'] = True

# Training Parameters
settings['start_iter'] = 1
settings['max_iter'] = 10
settings['val_after'] = 5

# Label Set Relationships
settings['C'] = ['back_pack', 'calculator', 'keyboard', 'monitor', 'mouse', 'mug', 'bike', 'laptop_computer', 'headphones', 'projector']
settings['Cs_dash'] = ['bike_helmet', 'bookcase', 'bottle', 'desk_chair', 'desk_lamp', 'desktop_computer', 'file_cabinet', 'letter_tray', 'mobile_phone', 'paper_notebook']
settings['Ct_dash'] = ['pen', 'phone', 'printer', 'punchers', 'ring_binder', 'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser', 'trash_can']

settings['num_C'] = len(settings['C'])
settings['num_Cs_dash'] = len(settings['Cs_dash'])
settings['num_Ct_dash'] = len(settings['Ct_dash'])
settings['num_Cs'] = settings['num_C'] + settings['num_Cs_dash']
settings['num_Ct'] = settings['num_C'] + settings['num_Ct_dash']

# Batch Sizes and Sample Counts
settings['batch_size'] = 64
settings['num_positive_samples'] = 32
settings['num_negative_samples'] = 32
settings['num_positive_images'] = settings['batch_size']
settings['num_negative_images'] = settings['batch_size']

# Model Parameters
settings['cnn_to_use'] = 'resnet50'
settings['Fs_dims'] = 256
settings['softmax_temperature'] = 1
settings['online_augmentation_90_degrees'] = True  # Augmentation during training
settings['val_aug_imgs_mean_before_softmax'] = False
settings['val_aug_imgs_mean_after_softmax'] = True

# Weights and Experiment Name
settings['load_weights'] = False
settings['load_exp_name'] = 'None'
settings['exp_name'] = 'usfda_office_31_DtoA'

# Optimizers for Losses
settings['optimizer'] = {
    'classification': ['M', 'Fs', 'Cs', 'Cn'],
    'pos_img_recon': ['Fs', 'G'],
    'pos_sample_recon': ['Fs', 'G'],
    'logsoftmax': ['Fs'],
}

# Loss Usage
settings['use_loss'] = {
    'classification': True,
    'pos_img_recon': True,
    'pos_sample_recon': True,
    'logsoftmax': True,
}

settings['losses_after_enough_iters'] = ['logprob', 'logsoftmax', 'pos_sample_recon']
settings['classification_weight'] = [1, 0.2]  # Hyperparameter alpha

# Model Training Toggles
settings['to_train'] = {
    'M': False,  # Frozen ResNet-50
    'Fs': True,
    'Ft': False,
    'G': True,
    'Cs': True,
    'Cn': True,
}

# Learning Rate and Device Configuration
# Learning Rate and Device Configuration
settings['lr'] = 1e-4
settings['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    settings['gpu'] = torch.cuda.current_device()
    torch.cuda.set_device(settings['gpu'])
else:
    settings['gpu'] = None  # Set to None when using CPU

# Print device information for verification
print(f"Using device: {settings['device']}")

settings['dataset_exp_name'] = 'usfda_office_31_DtoA'

settings['weights_path'] = os.path.join(server_root_path, 'Office-31', settings['dataset_exp_name'], 'weights')
settings['summaries_path'] = os.path.join(server_root_path, 'Office-31', settings['dataset_exp_name'], 'summaries')

# Loading Pretrained Weights
if settings['load_weights']:
    best_weights = natsorted(glob(os.path.join(settings['weights_path'], settings['load_exp_name'], '*.pth')))[-1]
    settings['load_weights_path'] = best_weights

# Dataset and Paths
settings['dataset_path'] = os.path.join(server_root_path, 'Office-31', settings['dataset_exp_name'], 'index_lists')
settings['negative_data_path'] = os.path.join(server_root_path, 'Office-31', settings['dataset_exp_name'], 'negative_images')
settings['negative_mask_path'] = os.path.join(server_root_path, 'Office-31', settings['dataset_exp_name'], 'negative_masks')

# Print settings to verify
print(settings)

Using device: cuda
{'running_supervised': True, 'start_iter': 1, 'max_iter': 10, 'val_after': 5, 'C': ['back_pack', 'calculator', 'keyboard', 'monitor', 'mouse', 'mug', 'bike', 'laptop_computer', 'headphones', 'projector'], 'Cs_dash': ['bike_helmet', 'bookcase', 'bottle', 'desk_chair', 'desk_lamp', 'desktop_computer', 'file_cabinet', 'letter_tray', 'mobile_phone', 'paper_notebook'], 'Ct_dash': ['pen', 'phone', 'printer', 'punchers', 'ring_binder', 'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser', 'trash_can'], 'num_C': 10, 'num_Cs_dash': 10, 'num_Ct_dash': 11, 'num_Cs': 20, 'num_Ct': 21, 'batch_size': 64, 'num_positive_samples': 32, 'num_negative_samples': 32, 'num_positive_images': 64, 'num_negative_images': 64, 'cnn_to_use': 'resnet50', 'Fs_dims': 256, 'softmax_temperature': 1, 'online_augmentation_90_degrees': True, 'val_aug_imgs_mean_before_softmax': False, 'val_aug_imgs_mean_after_softmax': True, 'load_weights': False, 'load_exp_name': 'None', 'exp_name': 'usfda_office_

In [4]:
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import pdb
import torch
from torch.autograd import Variable
from torchvision import models

Fs_dims = settings['Fs_dims']
cnn_to_use = settings['cnn_to_use']

class CustomResNet(nn.Module):

    def __init__(self):

        super(CustomResNet, self).__init__()

        temp_resnet = models.resnet50(pretrained=True)
        self.features = nn.Sequential(*[x for x in list(temp_resnet.children())[:-1]]) # Upto the avgpool layer

    def forward(self, x):

        feats = self.features(x)
        return feats.view((x.shape[0], 2048))


class modnet(nn.Module):

    def __init__(self, num_C, num_Cs_dash, num_Ct_dash, cnn=cnn_to_use, additional_components=[]):

        super(modnet, self).__init__()

        # Frozen initial conv layers
        if cnn=='resnet50':
            self.M = CustomResNet()
        else:
            raise NotImplementedError('Not implemented for ' + str(cnn))

        self.Fs = nn.Sequential(
            nn.Linear(2048,1024),
            nn.ELU(),
            nn.Linear(1024,1024),
            nn.BatchNorm1d(1024),
            nn.ELU(),
            nn.Linear(1024,Fs_dims),
            nn.ELU(),
            nn.Linear(Fs_dims, Fs_dims),
            nn.BatchNorm1d(Fs_dims),
            nn.ELU(),
        )

        self.Ft = nn.Sequential(
            nn.Linear(2048,1024),
            nn.ELU(),
            nn.Linear(1024,1024),
            nn.BatchNorm1d(1024),
            nn.ELU(),
            nn.Linear(1024,Fs_dims),
            nn.ELU(),
            nn.Linear(Fs_dims, Fs_dims),
            nn.BatchNorm1d(Fs_dims),
            nn.ELU(),
        )

        self.G = nn.Sequential(
            nn.Linear(Fs_dims,1024),
            nn.ELU(),
            nn.Linear(1024,1024),
            nn.BatchNorm1d(1024),
            nn.ELU(),
            nn.Linear(1024,2048),
            nn.ELU(),
            nn.Linear(2048, 2048),
        )

        self.Cs = nn.Sequential(
            nn.Linear(Fs_dims, num_C + num_Cs_dash)
        )

        # Negative class classifier. Change this to vary the size of the negative class classifier.
        n = settings['num_C'] + settings['num_Cs_dash']
        num_negative_classes = int(n*(n-1)/2)
        # num_negative_classes = 150

        self.Cn = nn.Sequential(
            nn.Linear(Fs_dims, num_negative_classes)
        )

        self.components = {
            'M': self.M,
            'Fs': self.Fs,
            'Ft': self.Ft,
            'G': self.G,
            'Cs': self.Cs,
            'Cn': self.Cn,
        }

    def forward(self, x, which_fext='original'):
        raise NotImplementedError('Implemented a custom forward in train loop')


def no_param(model):
    param = 0
    for p in list(model.parameters()):
        n=1
        for i in list(p.size()):
            n*= i
        param += n
    return param


# if __name__=='__main__':
#     raise NotImplementedError('Please check README.md for execution details')

In [5]:
import numpy as np
from torch.utils.data import Dataset, DataLoader
from skimage import io, transform
from torchvision import transforms, utils
import torch
import os

class TemplateDataset(Dataset):

    def __init__(self, view_params_set, transform=None, random_choice=True, online_augmentation_90_degrees=None):

        self.view_params_set = view_params_set
        self.transform = transform
        self.data = {}
        self.random_choice = random_choice
        self.online_augmentation_90_degrees = online_augmentation_90_degrees

    def __len__(self):
        if self.random_choice:
            return 1000000
        else:
            return len(self.view_params_set)

    def get_random_90_degree_augmentation(self, img):
        assert (img.shape == (224, 224, 3)) or (img.shape == (32, 32, 3))

        if type(img) == np.ndarray:
            img = transforms.ToPILImage()(img)

        angle = np.random.choice([-90, 90])
        angle = int(angle)
        new_img = transforms.functional.rotate(img, angle)

        return new_img

    def __getitem__(self, idx):

        if self.random_choice:
            idx = np.random.choice(np.arange(len(self.view_params_set)), 1)[0]
        img_name = os.path.join(server_root_path, self.view_params_set[idx])
        image = io.imread(img_name)
        image_cp = image
        if self.online_augmentation_90_degrees is None:
            if settings['online_augmentation_90_degrees']:
                image_cp = self.get_random_90_degree_augmentation(image_cp)
        else:
            if self.online_augmentation_90_degrees:
                image_cp = self.get_random_90_degree_augmentation(image_cp)
        res = transforms.Compose([transforms.ToTensor()])
        image_cp = res(image_cp)
        if self.transform:
            image = self.transform(image)
        norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
        image = norm(image)
        self.data['img'] = torch.tensor(image)
        # print img_name
        # self.data['label'] = min(config.num_classes_known, int(img_name.split('_')[-4]))
        self.data['label'] = int(img_name.split('_')[-4])
        # print 'label', self.data['label']
        self.data['raw'] = image_cp
        self.data['filename'] = img_name

        return self.data.copy()

In [6]:
import torch
from torch.autograd import Variable
from torch.distributions.multivariate_normal import MultivariateNormal
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from natsort import natsorted

def PCA(data_pca, k):
    data_mean = torch.mean(data_pca,0)
    data_pca = data_pca - data_mean.expand_as(data_pca)
    U,S,V = torch.svd(torch.t(data_pca))
    Components = torch.mm(data_pca,U[:,:k])
    return Components

def Get_Full_Data(loader, net_G, which_fext = 'original', vgg_feats = True, which_level_features='default', use_numpy=False):

    label_list = []
    feature_list = []
    common_embedding_list = []
    prediction_list = []
    vggrec_list = []
    vggfeats_list = []

    L = len(loader)
    loader = enumerate(loader)
    i = 0

    with torch.no_grad():
        while(True):
            try:
                data = loader.next()[1]
                i+=1
                if i%10==0:
                    print(str(i) + ' / ' + str(L))
                # if i>320:
                #     break
            except:
                break

            x = Variable(data['img'][:, :3, :, :]).to(settings['device']).float()

            if which_level_features=='default':
                prediction, feature, vggrec, vggfeats = net_G.forward(x, which_fext=which_fext)
            else:
                prediction, feature, vggrec, vggfeats = net_G.forward_and_get_features(x, which_fext=which_fext, which_level_features=which_level_features)
                feature = feature.cpu()

            if vgg_feats:
                vggrec_list.append(vggrec.detach())
                vggfeats_list.append(vggfeats.detach())
            feature_list.append(feature.detach())
            prediction_list.append(prediction.detach())
            label = Variable(torch.LongTensor(data['label'])).to(settings['device']).float()
            label_list.append(label.detach())

        if use_numpy:
            data_x = np.concatenate(feature_list, 0)
            data_y = np.concatenate(label_list, 0)
            data_pred = np.concatenate(prediction_list, 0)
        else:
            data_x = torch.cat(feature_list, 0)
            data_y = torch.cat(label_list, 0)
            data_pred = torch.cat(prediction_list, 0)

        if vgg_feats:
            if use_numpy:
                data_vggrec = np.concatenate(vggrec_list, 0)
                data_vggfeats = np.concatenate(vggfeats_list, 0)
            else:
                data_vggrec = torch.cat(vggrec_list, 0)
                data_vggfeats = torch.cat(vggfeats_list, 0)
    if vgg_feats:
        return data_x, data_y, data_pred, data_vggrec, data_vggfeats
    else:
        return data_x, data_y, data_pred, None, None


def Get_Cov_Mean(features):
    # Features = (D, N)
    features_mean = torch.mean(features,dim=1)
    mean_cen_feat = features - features_mean[:, None]
    COV = 1.0/(mean_cen_feat.size(1)-1) * mean_cen_feat.mm(mean_cen_feat.t())
    MEAN = features_mean
    eps = 1e-3 * torch.eye(COV.shape[0]).to(settings['device'])
    return COV+eps, MEAN

def get_mu_sig(feature_trans, labels, num_classes, only_return = False):
    cov = []
    std = []
    mean = []
    boundary = []
    mean_value = []
    eig_dir = []
    COV, MEAN = Get_Cov_Mean(feature_trans.t())
    # print feature_trans.t().shape
    u,s,v = torch.svd(COV)
    eigen_direction_big_circle = torch.matmul(v, torch.sqrt(s))
    # here range = number of categories. hence, take care to change accordingly
    for i in tqdm(range(num_classes)):
        new_ft = feature_trans[labels==i]
        features = new_ft.t()
        # print features.shape
        tmp_cov, tmp_mean = Get_Cov_Mean(features)
        std.append(torch.sqrt(torch.diag(tmp_cov)))
        cov.append(tmp_cov)
        mean.append(tmp_mean)
        # print(tmp_cov)
        # print([float(tmp_cov[i, i]) for i in range(tmp_cov.shape[0])])
        newd = MultivariateNormal(tmp_mean, tmp_cov)
        llh = newd.log_prob(tmp_mean)
        mean_value.append(llh)
        u,s,v = torch.svd(tmp_cov)
        eigen_direction = torch.matmul(v, torch.sqrt(s))
        eig_dir.append(eigen_direction)

    data = {}
    for i in range(num_classes):
        # print mean[i], cov[i], eig_dir[i]
        data[str(i)] = [mean[i].data.cpu().numpy(), cov[i].data.cpu().numpy(), eig_dir[i].data.cpu().numpy()]
    data['full_circle'] = [MEAN.data.cpu().numpy(), COV.data.cpu().numpy(), eigen_direction_big_circle]

    if only_return:
        return data
    else:
        np.save('mu_sigmanew', data)
        return data

In [7]:
!pip install tensorboardX

Collecting tensorboardX
  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl.metadata (5.8 kB)
Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/101.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━[0m [32m92.2/101.7 kB[0m [31m2.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.7/101.7 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensorboardX
Successfully installed tensorboardX-2.6.2.2


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import sys
import os
from torchvision import transforms, utils
from glob import glob
import time
import cv2
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from torch.distributions.multivariate_normal import MultivariateNormal
import time
from tqdm import tqdm
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
import pdb
from sklearn.preprocessing import StandardScaler
from scipy.stats import norm
from scipy.interpolate import make_interp_spline, BSpline
from skimage import io
from torchvision.transforms import ToTensor

class TrainerG():

    def __init__(self, network, optimizer, exp_name, index_lists, settings):
        # print('function __init__')

        # Set the network and optimizer
        self.network = network
        self.to_train = settings['to_train']

        # Optimizers to use
        self.optimizer = optimizer
        self.which_optimizer = list(sorted(self.optimizer.keys()))
        print('\noptimizers: ' + str(self.which_optimizer) + '\n')

        # Save the settings
        self.settings = settings

        # Initialize the val and train writers
        self.val_writer = SummaryWriter(os.path.join(server_root_path, 'Office-31', exp_name, 'summaries', 'logdir_val'))
        self.train_writer = SummaryWriter(os.path.join(server_root_path, 'Office-31', exp_name, 'summaries', 'logdir_train'))

        # Extract commonly used settings
        self.batch_size = settings['batch_size']
        self.current_iteration = settings['start_iter']

        # Get the index lists
        [index_list_path_train_source, index_list_path_val_source, _, _, _, _] = index_lists
        self.index_list_train_source = np.load(index_list_path_train_source)
        self.index_list_val_source = np.load(index_list_path_val_source)
        self.index_list_train_negative = glob(self.settings['negative_data_path'] + '/*.png')

        # Get number of classes
        self.num_C = settings['num_C']
        self.num_Cs_dash = settings['num_Cs_dash']
        self.num_Ct_dash = settings['num_Ct_dash']
        self.num_Cs = settings['num_Cs']
        self.num_Ct = settings['num_Ct']

        # Initialize data loaders
        self.get_all_dataloaders()

        self.mu_sigma = None
        self.mu_sigma_distribution = None

        # Mu Sigma is calculated at each validation iteration. For the first iteration, no mu sigmas exist, so don't load unless start iteration is higher.
        if (self.current_iteration >= self.settings['val_after']):

            self.recalculate_mu_sigma()


    def get_mu_sigma_threshold(self, mu_sigma, mu_sigma_distro, maximum=True):
        # print('function get_mu_sigma_threshold')

        MU, COV = torch.from_numpy(mu_sigma[0]).to(device), torch.from_numpy(mu_sigma[1]).to(device)
        u,s,v = torch.svd(COV)
        if maximum:
            eigen_direction = v[:, 0]
        else:
            eigen_direction = torch.matmul(v, torch.sqrt(s))

        threshvec = MU + 3 * eigen_direction
        return mu_sigma_distro.log_prob(threshvec)


    def get_histogram(self, features, bin_size=0.01, normalize=True):
        # print('function get_histogram')

        F = [int(float(f)/bin_size) for f in features]
        D = {}
        for f in F:
            if f in D: D[f] += 1
            else: D[f] = 1
        x = np.array([float(k*bin_size) for k in sorted(D.keys())])
        y = np.array([float(D[f]) for f in sorted(D.keys())])

        if normalize:
            y = y / np.sum(y)

        return x, y


    def get_all_dataloaders(self):
        # print('function get_all_dataloaders')

        dataset_train = TemplateDataset(self.index_list_train_source, transform=transforms.Compose([transforms.ToTensor()]), random_choice=False)
        self.loader_train = DataLoader(dataset_train, batch_size=self.batch_size, shuffle=True, num_workers=2)

        dataset_source_val = TemplateDataset(self.index_list_val_source, transform=transforms.Compose([transforms.ToTensor()]), random_choice=False)
        self.loader_source_val = DataLoader(dataset_source_val, batch_size=self.batch_size, shuffle=True, num_workers=2)

        dataset_source_train_negative = TemplateDataset(self.index_list_train_negative, transform=transforms.Compose([transforms.ToTensor()]), random_choice=False)
        self.loader_train_negative = DataLoader(dataset_source_train_negative, batch_size=self.batch_size, shuffle=True, num_workers=2)


    def get_loss(self, which_loss):
        # print('function get_loss')

        if which_loss == 'classification':
            outs = torch.cat([self.features['Cs'], self.features['Cn']], dim=-1)
            topK = self.settings['num_positive_images']
            positive_image_loss = nn.CrossEntropyLoss(reduction='mean')(outs[:topK], self.gt[:topK].long())
            negative_image_loss = nn.CrossEntropyLoss(reduction='mean')(outs[topK:], self.gt[topK:].long())
            w1, w2 = self.settings['classification_weight']
            loss = w1 * positive_image_loss + w2 * negative_image_loss

        elif which_loss == 'pos_img_recon':
            topK = self.settings['num_positive_samples']
            loss = torch.mean( torch.sum(torch.pow(self.features['M'][:topK] - self.features['G'][:topK], 2), dim=-1) )

        elif which_loss == 'pos_sample_recon':
            topK = self.settings['num_positive_samples']
            loss = torch.mean( torch.sum( torch.pow(self.features['Fs_sample'][:topK] - self.features['Fs_sample_recon'][:topK], 2), dim=-1 ) )

        elif which_loss == 'logsoftmax': # Increase the class confidence by decreasing the distance
            # This is L_p defined in the paper. While implementing in PyTorch, we decrease the distance to the corresponding cluster,
            # which has the same effect as increasing the relative posterior confidence.
            topK = self.settings['num_positive_images']
            malhanobis_squared_mat = torch.zeros((topK, self.num_C + self.num_Cs_dash)).to(device)
            for c in range(self.num_C + self.num_Cs_dash):
                mu = torch.from_numpy(self.mu_sigma[str(c)][0]).to(device)
                malhanobis_squared_distance = (self.mu_sigma_distribution[str(c)].log_prob(mu) - self.mu_sigma_distribution[str(c)].log_prob(self.features['Fs'][:topK])) ** 2
                malhanobis_squared_mat[:, c] = malhanobis_squared_distance
            norm_malhanobis_mat = malhanobis_squared_mat / torch.sum(malhanobis_squared_mat, dim=-1, keepdim=True)
            corr_conf = torch.zeros((self.settings['batch_size'],)).to(device)
            for i in range(self.settings['batch_size']):
                corr_conf[i] = norm_malhanobis_mat[i, self.gt[i]]
            loss = torch.mean( torch.log(corr_conf), dim=0 )

        else:
            raise NotImplementedError('Not implemented loss function ' + str(which_loss))

        self.summary_dict['loss/' + str(which_loss)] = loss.data.cpu().numpy()
        return loss


    def recalculate_mu_sigma(self):
        # print('function recalculate_mu_sigma')
        print('\nRecalculating mu sigma\n')
        self.mu_sigma = self.get_mu_sigma()
        self.mu_sigma_distribution = {}

        for key in self.mu_sigma:
            mu, sigma = torch.from_numpy(self.mu_sigma[key][0]).to(device), torch.from_numpy(self.mu_sigma[key][1]).to(device)
            self.mu_sigma_distribution[key] = torch.distributions.MultivariateNormal(mu, sigma)

        self.mu_sigma_threshold = []
        for key in range(self.num_C + self.num_Cs_dash):
            self.mu_sigma_threshold.append( self.get_mu_sigma_threshold(self.mu_sigma[str(key)], self.mu_sigma_distribution[str(key)]) )
        # print(self.mu_sigma_threshold)
        self.mu_sigma_threshold = torch.FloatTensor(self.mu_sigma_threshold).view((1, self.num_C + self.num_Cs_dash)).to(device)


    def loss(self):
        # print('function loss')

        # ==================================
        # ====== Accuracy over images ======
        # ==================================
        concat_outputs_img = torch.cat([self.features['Cs'], self.features['Cn']], dim=-1)
        concat_softmax_img = F.softmax(concat_outputs_img/self.settings['softmax_temperature'], dim=-1)
        pred_classes_img = torch.argmax(concat_softmax_img, dim=-1)
        pos = self.settings['num_positive_images']

        # Source Accuracy (Positive images)
        pred_classes_pos_img = pred_classes_img[:pos]
        gt_classes_pos_img = self.gt[:pos]
        if len(pred_classes_pos_img) != 0:
            source_acc_pos_img = (pred_classes_pos_img.float() == gt_classes_pos_img.float()).float().mean()
            self.summary_dict['acc/source_acc_pos_img'] = source_acc_pos_img

        # Source Accuracy (Negative images)
        pred_classes_neg_img = pred_classes_img[pos:]
        gt_classes_neg_img = self.gt[pos:]
        if len(pred_classes_neg_img) != 0:
            source_acc_neg_img = (pred_classes_neg_img.float() == gt_classes_neg_img.float()).float().mean()
            source_binary_acc_neg_img = (pred_classes_neg_img.float() >= self.num_Cs).float().mean()
            self.summary_dict['acc/source_acc_neg_img'] = source_acc_neg_img
            self.summary_dict['acc/source_binary_acc_neg_img'] = source_binary_acc_neg_img

        # ===================================
        # ====== Accuracy over samples ======
        # ===================================
        if self.current_iteration >= self.settings['val_after']:
            concat_outputs_sample = torch.cat([self.features['Cs_sample'], self.features['Cn_sample']], dim=-1)
            concat_softmax_sample = F.softmax(concat_outputs_sample/self.settings['softmax_temperature'], dim=-1)
            pred_classes_sample = torch.argmax(concat_softmax_sample, dim=-1)
            pred_classes_sample = pred_classes_sample.view((pred_classes_sample.shape[0],))
            pos = self.settings['num_positive_samples']

            # Sample Accuracy (Positive samples)
            pred_classes_pos_samples = pred_classes_sample[:pos]
            gt_classes_pos_samples = self.gt_sample[:pos]
            source_acc_pos_samples = (pred_classes_pos_samples.float() == gt_classes_pos_samples.float()).float().mean()
            self.summary_dict['acc/source_acc_pos_samples'] = source_acc_pos_samples

            # # Sample Accuracy (Negative samples)
            # pred_classes_neg_samples = pred_classes_sample[pos:]
            # gt_classes_neg_samples = self.gt_sample[pos:]
            # source_acc_neg_samples = (pred_classes_neg_samples.float() == gt_classes_neg_samples.float()).float().mean()
            # self.summary_dict['acc/source_acc_neg_samples'] = source_acc_neg_samples

        if self.phase == 'train':

            # ====== BACKPROP LOSSES ======
            enough_iters = (self.current_iteration >= self.settings['val_after'])
            l = len(self.which_optimizer)
            current_loss = self.which_optimizer[self.current_iteration%l]

            if self.settings['use_loss'][current_loss] and self.backward:

                print('\nApplying loss ' + str(self.which_optimizer[self.current_iteration%l]))

                if current_loss in self.settings['losses_after_enough_iters']:
                    if self.current_iteration >= self.settings['val_after']:
                        # print('{} >= {}'.format(self.current_iteration, self.settings['val_after']))
                        self.optimizer[self.which_optimizer[self.current_iteration%l]].zero_grad()
                        loss = self.get_loss(which_loss=self.which_optimizer[self.current_iteration%l])
                        self.summary_dict['loss/' + str(self.which_optimizer[self.current_iteration%l])] = loss.cpu().detach().numpy()
                        loss.backward()
                        self.optimizer[self.which_optimizer[self.current_iteration%l]].step()
                else:
                    self.optimizer[self.which_optimizer[self.current_iteration%l]].zero_grad()
                    loss = self.get_loss(which_loss=self.which_optimizer[self.current_iteration%l])
                    loss.backward()
                    self.optimizer[self.which_optimizer[self.current_iteration%l]].step()

        self.current_iteration += 1


    # def get_negative_images(self, num_images):
    #     # print('function get_negative_images')

    #     self.loader_train_negative

    #     return images, gt


    def get_sample_embeddings(self, num_samples):
        # print('function get_sample_embeddings')

        sample_size = 10000

        # Positive samples
        N = self.settings['num_positive_samples']
        randclasses = torch.randint(0, self.num_C + self.num_Cs_dash, (N,)).to(device)
        pos_samples = [self.mu_sigma_distribution[str(int(c))].sample((1,)) for c in randclasses]
        pos_gt = randclasses.clone().float()
        pos_samples = torch.cat(pos_samples, dim=0).to(device)

        return pos_samples, pos_gt



    def forward(self):
        # print('function forward')

        self.gt = Variable(torch.LongTensor(self.data['label'])).to(device).long()
        img_source = Variable(self.data['img'][:, :3, :, :]).to(device).float()
        self.gt_neg = Variable(torch.LongTensor(self.data_neg['label'])).to(device).long() + self.num_Cs
        img_neg_source = Variable(self.data_neg['img'][:, :3, :, :]).to(device).float()

        self.gt = torch.cat([self.gt, self.gt_neg], dim=0)

        images = torch.cat([img_source, img_neg_source], dim=0) # Concatenate positive and negative images

        self.features = {}

        with torch.no_grad():
            self.features['M'] = self.network.M(images)
        self.features['Fs'] = self.network.Fs(self.features['M'])
        self.features['G'] = self.network.G(self.features['Fs'])
        self.features['Cs'] = self.network.Cs(self.features['Fs'])
        self.features['Cn'] = self.network.Cn(self.features['Fs'])

        if self.current_iteration >= self.settings['val_after']:

            if (self.mu_sigma_distribution == None) or (self.current_iteration % self.settings['val_after'] == 0):
                self.recalculate_mu_sigma()

            self.features['Fs_sample'], self.gt_sample = self.get_sample_embeddings((images.shape[0],)) # top half - positive, bottom half - negative
            self.features['G_sample'] = self.network.G(self.features['Fs_sample'])
            self.features['Fs_sample_recon'] = self.network.Fs(self.features['G_sample'])
            self.features['Cs_sample'] = self.network.Cs(self.features['Fs_sample'])
            self.features['Cn_sample'] = self.network.Cn(self.features['Fs_sample'])


    def train(self):
        # print('function train')

        self.phase = 'train'

        self.summary_dict = {}

        try:
            # self.data = self.dataloader_train.next()[1]
            # self.data_neg = self.dataloader_train_negative.next()[1]

            self.data = next(self.dataloader_train)[1]
            self.data_neg = next(self.dataloader_train_negative)[1]

            if self.data['img'].shape[0] < self.settings['batch_size']:
                self.dataloader_train = enumerate(self.loader_train)
                self.data = next(self.dataloader_train)[1]

            if self.data_neg['img'].shape[0] < self.settings['batch_size']:
                self.dataloader_train_negative = enumerate(self.loader_train_negative)
                self.data_neg = next(self.dataloader_train_negative)[1]
        except:
            self.dataloader_train = enumerate(self.loader_train)
            self.data = next(self.dataloader_train)[1]
            self.dataloader_train_negative = enumerate(self.loader_train_negative)
            self.data_neg = next(self.dataloader_train_negative)[1]

        self.forward()
        self.loss()

        return self.summary_dict['acc/source_acc_pos_img']


    def val(self):
        # print('function val')

        self.phase = 'val'

        self.summary_dict = {}

        self.forward()

        self.loss()


    def log_errors(self, phase, iteration=None):
        # print('function log_errors')

        print('log errors phase: ' + str(phase) + '\n')
        print(self.summary_dict.keys())

        for x in list(sorted(self.summary_dict.keys())):
            print(x + ' : ' + str(float(self.summary_dict[x])))

            if phase == 'val':
                self.val_writer.add_scalar(x, self.summary_dict[x], self.current_iteration)
            elif phase == 'train':
                self.train_writer.add_scalar(x, self.summary_dict[x], self.current_iteration)


    def set_mode_val(self):
        # print('function set_mode_val')

        self.network.eval()
        self.backward = False
        for p in self.network.parameters():
            p.requires_grad = False
            p.volatile = True


    def set_mode_train(self):
        # print('function set_mode_train')

        self.network.train()
        self.backward = True
        for p in self.network.parameters():
            p.requires_grad = True
            p.volatile = False

        for comp in self.settings['to_train']:
            if self.settings['to_train'][comp] == False:
                self.network.components[comp].eval()
                for p in self.network.components[comp].parameters():
                    p.requires_grad = False
                    p.volatile = True


    def val_over_val_set(self):
        # print('function val_over_val_set')

        with torch.no_grad():

            enough_iters = (self.current_iteration >= self.settings['val_after'])

            self.summary_dict = {}

            dataset_source_val = TemplateDataset(self.index_list_val_source, transform=transforms.Compose([transforms.ToTensor()]), random_choice=False)
            dataloader_source = DataLoader(dataset_source_val, batch_size=self.batch_size, shuffle=True, num_workers=2)

            # ----------------------
            # Source validation Data
            # ----------------------

            print('\nValidating on source validation data')

            num_C = self.num_C
            num_Cs_dash = self.num_Cs_dash
            num_Ct_dash = self.num_Ct_dash
            num_Cs = self.num_Cs
            num_Ct = self.num_Ct

            with torch.no_grad():

                classes = list(range(num_C+num_Cs_dash))

                avg_acc = {c:0 for c in classes}
                avg_count = {c:0 for c in classes}

                idx = -1

                for data in tqdm(dataloader_source):
                    idx += 1
                    x = Variable(data['img'][:, :3, :, :]).to(self.settings['device']).float()
                    labels_source = Variable(data['label']).to(self.settings['device'])

                    M = self.network.components['M'](x)
                    Fs = self.network.components['Fs'](M)
                    G = self.network.components['G'](Fs)
                    Cs = self.network.components['Cs'](Fs)
                    Cn = self.network.components['Cn'](Fs)

                    concat_outputs = torch.cat([Cs, Cn], dim=-1)
                    concat_softmax = F.softmax(concat_outputs/self.settings['softmax_temperature'], dim=-1)

                    max_act, pred = torch.max(concat_softmax, dim=-1)

                    for c in classes:
                        avg_acc[c] += (pred[labels_source==c] == labels_source[labels_source==c]).float().sum()
                        avg_count[c] += pred[labels_source==c].shape[0]

                # average accuracy
                avg = 0
                num_classes = num_C + num_Cs_dash
                for c in classes:
                    if avg_count[c] == 0:
                        avg += 0
                    else:
                        avg += (float(avg_acc[c]) / float(avg_count[c]))
                avg /= float(num_classes)
                self.summary_dict['acc/source_avg'] = avg

            return self.summary_dict


    def get_mu_sigma(self): # For tracking variance of each class in the training dataset
        # print('function get_mu_sigma')

        dataset_train = TemplateDataset(self.index_list_train_source, transform=transforms.Compose([transforms.ToTensor()]), random_choice=False)
        dataloader_source = DataLoader(dataset_train, batch_size=self.batch_size, shuffle=True, num_workers=2)

        features_source = []
        labels_source = []

        with torch.no_grad():

            for data in tqdm(dataloader_source):
                img = data['img'][:, :3, :, :].float().to(device)
                lab = data['label'].float().to(device)

                M = self.network.components['M'](img)
                Fs = self.network.components['Fs'](M)

                features_source.append(Fs)
                labels_source.append(lab)

            features_source = torch.cat(features_source, dim=0)
            labels_source = torch.cat(labels_source, dim=0)

        mu_sigma_data = get_mu_sig(features_source, labels_source, self.num_C + self.num_Cs_dash, only_return=False)

        return mu_sigma_data

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import sys
import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from torch.utils.data import DataLoader
from time import time
from tqdm import tqdm
import subprocess
import warnings
from torchvision import transforms
from tensorboardX import SummaryWriter
import shutil

warnings.simplefilter("ignore", UserWarning)

# ======================= SANITY CHECK ======================= #

assert settings['running_supervised'], 'ERROR!! Config not set to run supervised trainer!!'
print('######## SANITY CHECK ########')
for key in sorted(settings.keys()):
	print('{}: {}'.format(key, settings[key]))

# ip = raw_input('continue? (y/n): ')
# if ip.lower() == 'y' or ip.lower() == 'yes':
# 	pass
# else:
# 	print('Decided not to execute!')
# 	exit()

# ==================== END OF SANITY CHECK ==================== #

max_val_source_acc = -10000
itt_delete = []

def main():

    print('\n Setting up data sources ...')

    # ====== DELETE PAST RUNS ======
    #torch.cuda.set_device(settings['gpu'])
    exp_name = settings['exp_name']
    # subprocess.call(["rm", "-rf", os.path.join(settings['weights_path'],exp_name)])
    # subprocess.call(["mkdir", os.path.join(settings['weights_path'],exp_name)])
    # subprocess.call(["rm", "-rf", os.path.join(settings['summaries_path'],exp_name)])
    # subprocess.call(["mkdir", os.path.join(settings['summaries_path'],exp_name)])
    # subprocess.call(["mkdir", os.path.join(settings['summaries_path'],exp_name)+"/logdir_train"])
    # subprocess.call(["mkdir", os.path.join(settings['summaries_path'],exp_name)+"/logdir_val"])

    # Define the paths for weights and summaries
    weights_path = settings['weights_path']
    summaries_path = settings['summaries_path']

    # Remove the existing directory for experiment's weights and summaries if they exist
    weights_dir = os.path.join(weights_path)
    if os.path.exists(weights_dir):
        shutil.rmtree(weights_dir)  # Delete the directory and its contents
    os.makedirs(weights_dir)  # Create a new empty directory for weights

    summaries_dir = os.path.join(summaries_path)
    if os.path.exists(summaries_dir):
        shutil.rmtree(summaries_dir)  # Delete the directory and its contents
    os.makedirs(summaries_dir)  # Create a new empty directory for summaries

    # Create subdirectories for training and validation logs
    train_log_dir = os.path.join(summaries_dir, "logdir_train")
    val_log_dir = os.path.join(summaries_dir, "logdir_val")
    os.makedirs(train_log_dir)  # Create training log directory
    os.makedirs(val_log_dir)  # Create validation log directory

    with open(os.path.join(settings['summaries_path'], 'txt'), 'w') as history_file:
      print('saving in ' + os.path.join(settings['summaries_path'], 'txt'))
      history_file.write('\n===== x ===== x =====\n')
      for key in sorted(settings.keys()):
        history_file.write('{}: {}\n'.format(key, settings[key]))

    # ====== DEFINE DATA SOURCES ======
    index_list_path_train_source = os.path.join(settings['dataset_path'], 'source_images_index_list_train.npy')
    index_list_path_val_source = os.path.join(settings['dataset_path'], 'source_images_index_list_val.npy')
    index_lists = [index_list_path_train_source, index_list_path_val_source, None, None, None, None]

    # ====== CREATE NETWORK ======
    print('\n Building network ...')
    network = modnet(settings['num_C'], settings['num_Cs_dash'], settings['num_Ct_dash'], cnn=settings['cnn_to_use']).to(device)

    # Load weights
    if settings['load_weights']:
      dict_to_load = torch.load(settings['load_weights_path'])
      for component in dict_to_load:
        network.components[component].load_state_dict(dict_to_load[component])

    # ====== DEFINE OPTIMIZERS ======
    print('\n Setting up optimizers ...')
    optimizer = {}

    for key in settings['use_loss']:
      if settings['use_loss'][key]:
        to_train = []
        for comp in settings['optimizer'][key]:
          if settings['to_train'][comp]:
            to_train.append({'params': network.components[comp].parameters(), 'lr':settings['lr']})#, 'momentum':0.1, 'nesterov':True, 'weight_decay':0.1})
        optimizer[key] = optim.Adam(params = to_train)

    def trainval(network, optimizer, exp_name, index_lists, settings):

      global least_val_loss
      global itt_delete

      train_iter = settings['start_iter']

      trainer_G = TrainerG(network, optimizer, exp_name, index_lists, settings)

      train_acc_list = []

      while True:

        print ("\n----------- train_iter " + str(train_iter) + ' -----------\n')
        trainer_G.set_mode_train()
        acc_gen = trainer_G.train()
        trainer_G.log_errors('train')
        train_acc_list.append(acc_gen.data.cpu().numpy())

        if train_iter%settings['val_after'] == 0:

          print('validating')

          trainer_G.set_mode_val()
          min_val_flag=test(trainer_G)

          print('min_val_flag', min_val_flag)

          if(min_val_flag):
            print ("Saving - iteration", train_iter)
            dict_to_save = {component:network.components[component].cpu().state_dict() for component in network.components}

            # dict_to_save = {
            # 	'vgg': network.vgg.cpu().state_dict(),
            # 	'feature_ext': network.feature_ext.cpu().state_dict(),
            # 	'classifier_known': network.classifier_known.cpu().state_dict(),
            # 	'dec_feature_ext': network.dec_feature_ext.cpu().state_dict(),
            # 	'classifier_unknown': network.classifier_unknown.cpu().state_dict(),
            # 	'classifier_open_set': network.classifier_open_set.cpu().state_dict()
            # }

            torch.save(dict_to_save, os.path.join(os.path.join(settings['weights_path'])+'/', 'best_' + str(train_iter) + '.pth'))
            network.to(device)
            itt_delete.append(train_iter)

            # if(len(itt_delete)>100):
            #   for k in itt_delete[:-100]:
            #     subprocess.call(['rm', os.path.join(os.path.join(settings['weights_path'])+'/', 'best_' + str(k) + '.pth')])
            #   itt_delete = itt_delete[-100:]

            # Limit the number of saved checkpoints (keeping only the last 100)
            if len(itt_delete) > 100:
                for k in itt_delete[:-100]:
                    checkpoint_path = os.path.join(settings['weights_path'], 'best_' + str(k) + '.pth')
                    if os.path.exists(checkpoint_path):  # Check if the file exists before attempting to delete
                        os.remove(checkpoint_path)  # Use os.remove to delete files in Colab

                # Keep only the last 100 iterations
                itt_delete = itt_delete[-100:]

          if train_iter > settings['max_iter']:
            break

        train_iter += 1

      print("Train Acc: ", train_acc_list[settings['max_iter']])


    def test(trainer_G, iteration=None):
      global max_val_source_acc

      summary_dict=trainer_G.val_over_val_set()

      print('source set validation')
      print(summary_dict)
      val_acc_source = summary_dict['acc/source_avg']
      max_val_source_acc = max(val_acc_source,max_val_source_acc)
      trainer_G.log_errors('val')

      if( (max_val_source_acc==val_acc_source) ):
        return True
      else:
        return False

    # ====== CALL TRAINING AND VALIDATION PROCESS ======
    trainval(network, optimizer, exp_name, index_lists, settings)

if __name__ == '__main__':
	main()

######## SANITY CHECK ########
C: ['back_pack', 'calculator', 'keyboard', 'monitor', 'mouse', 'mug', 'bike', 'laptop_computer', 'headphones', 'projector']
Cs_dash: ['bike_helmet', 'bookcase', 'bottle', 'desk_chair', 'desk_lamp', 'desktop_computer', 'file_cabinet', 'letter_tray', 'mobile_phone', 'paper_notebook']
Ct_dash: ['pen', 'phone', 'printer', 'punchers', 'ring_binder', 'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser', 'trash_can']
Fs_dims: 256
batch_size: 64
classification_weight: [1, 0.2]
cnn_to_use: resnet50
dataset_exp_name: usfda_office_31_DtoA
dataset_path: /content/drive/MyDrive/Office-31/usfda_office_31_DtoA/index_lists
device: cuda
exp_name: usfda_office_31_DtoA
gpu: 0
load_exp_name: None
load_weights: False
losses_after_enough_iters: ['logprob', 'logsoftmax', 'pos_sample_recon']
lr: 0.0001
max_iter: 10
negative_data_path: /content/drive/MyDrive/Office-31/usfda_office_31_DtoA/negative_images
negative_mask_path: /content/drive/MyDrive/Office-31/usfda_office_31_D

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 192MB/s]



 Setting up optimizers ...

optimizers: ['classification', 'logsoftmax', 'pos_img_recon', 'pos_sample_recon']


----------- train_iter 1 -----------


Applying loss logsoftmax
log errors phase: train

dict_keys(['acc/source_acc_pos_img', 'acc/source_acc_neg_img', 'acc/source_binary_acc_neg_img'])
acc/source_acc_neg_img : 0.0
acc/source_acc_pos_img : 0.0
acc/source_binary_acc_neg_img : 0.96875

----------- train_iter 2 -----------


Applying loss pos_img_recon
log errors phase: train

dict_keys(['acc/source_acc_pos_img', 'acc/source_acc_neg_img', 'acc/source_binary_acc_neg_img', 'loss/pos_img_recon'])
acc/source_acc_neg_img : 0.0
acc/source_acc_pos_img : 0.0
acc/source_binary_acc_neg_img : 0.890625
loss/pos_img_recon : 875.0618896484375

----------- train_iter 3 -----------


Applying loss pos_sample_recon
log errors phase: train

dict_keys(['acc/source_acc_pos_img', 'acc/source_acc_neg_img', 'acc/source_binary_acc_neg_img'])
acc/source_acc_neg_img : 0.0
acc/source_acc_pos_img : 0.0
ac

100%|██████████| 5/5 [00:05<00:00,  1.09s/it]
100%|██████████| 20/20 [00:00<00:00, 39.62it/s]



Applying loss logsoftmax
log errors phase: train

dict_keys(['acc/source_acc_pos_img', 'acc/source_acc_neg_img', 'acc/source_binary_acc_neg_img', 'acc/source_acc_pos_samples', 'loss/logsoftmax'])
acc/source_acc_neg_img : 0.0
acc/source_acc_pos_img : 0.234375
acc/source_acc_pos_samples : 0.21875
acc/source_binary_acc_neg_img : 0.875
loss/logsoftmax : -9.238740921020508
validating

Validating on source validation data


100%|██████████| 1/1 [00:31<00:00, 31.37s/it]


source set validation
{'acc/source_avg': 0.08333333333333333}
log errors phase: val

dict_keys(['acc/source_avg'])
acc/source_avg : 0.08333333333333333
min_val_flag True
Saving - iteration 5

----------- train_iter 6 -----------


Applying loss pos_img_recon
log errors phase: train

dict_keys(['acc/source_acc_pos_img', 'acc/source_acc_neg_img', 'acc/source_binary_acc_neg_img', 'acc/source_acc_pos_samples', 'loss/pos_img_recon'])
acc/source_acc_neg_img : 0.0
acc/source_acc_pos_img : 0.203125
acc/source_acc_pos_samples : 0.21875
acc/source_binary_acc_neg_img : 0.84375
loss/pos_img_recon : 798.2153930664062

----------- train_iter 7 -----------


Applying loss pos_sample_recon
log errors phase: train

dict_keys(['acc/source_acc_pos_img', 'acc/source_acc_neg_img', 'acc/source_binary_acc_neg_img', 'acc/source_acc_pos_samples', 'loss/pos_sample_recon'])
acc/source_acc_neg_img : 0.015625
acc/source_acc_pos_img : 0.328125
acc/source_acc_pos_samples : 0.21875
acc/source_binary_acc_neg_img : 0.7

100%|██████████| 5/5 [00:03<00:00,  1.31it/s]
100%|██████████| 20/20 [00:00<00:00, 90.68it/s]



Applying loss pos_img_recon
log errors phase: train

dict_keys(['acc/source_acc_pos_img', 'acc/source_acc_neg_img', 'acc/source_binary_acc_neg_img', 'acc/source_acc_pos_samples', 'loss/pos_img_recon'])
acc/source_acc_neg_img : 0.0
acc/source_acc_pos_img : 0.46875
acc/source_acc_pos_samples : 0.5625
acc/source_binary_acc_neg_img : 0.703125
loss/pos_img_recon : 789.9507446289062
validating

Validating on source validation data


100%|██████████| 1/1 [00:00<00:00,  1.32it/s]


source set validation
{'acc/source_avg': 0.39166666666666666}
log errors phase: val

dict_keys(['acc/source_avg'])
acc/source_avg : 0.39166666666666666
min_val_flag True
Saving - iteration 10

----------- train_iter 11 -----------


Applying loss pos_sample_recon
log errors phase: train

dict_keys(['acc/source_acc_pos_img', 'acc/source_acc_neg_img', 'acc/source_binary_acc_neg_img', 'acc/source_acc_pos_samples', 'loss/pos_sample_recon'])
acc/source_acc_neg_img : 0.015625
acc/source_acc_pos_img : 0.609375
acc/source_acc_pos_samples : 0.375
acc/source_binary_acc_neg_img : 0.765625
loss/pos_sample_recon : 161.9952392578125

----------- train_iter 12 -----------


Applying loss classification
log errors phase: train

dict_keys(['acc/source_acc_pos_img', 'acc/source_acc_neg_img', 'acc/source_binary_acc_neg_img', 'acc/source_acc_pos_samples', 'loss/classification'])
acc/source_acc_neg_img : 0.0
acc/source_acc_pos_img : 0.484375
acc/source_acc_pos_samples : 0.59375
acc/source_binary_acc_neg_im

100%|██████████| 5/5 [00:06<00:00,  1.21s/it]
100%|██████████| 20/20 [00:00<00:00, 85.00it/s]



Applying loss pos_sample_recon
log errors phase: train

dict_keys(['acc/source_acc_pos_img', 'acc/source_acc_neg_img', 'acc/source_binary_acc_neg_img', 'acc/source_acc_pos_samples', 'loss/pos_sample_recon'])
acc/source_acc_neg_img : 0.015625
acc/source_acc_pos_img : 0.71875
acc/source_acc_pos_samples : 0.6875
acc/source_binary_acc_neg_img : 0.6875
loss/pos_sample_recon : 114.83320617675781
validating

Validating on source validation data


100%|██████████| 1/1 [00:00<00:00,  1.45it/s]


source set validation
{'acc/source_avg': 0.4666666666666666}
log errors phase: val

dict_keys(['acc/source_avg'])
acc/source_avg : 0.4666666666666666
min_val_flag True
Saving - iteration 15
0.609375


In [8]:
import os
from glob import glob
from natsort import natsorted

import matplotlib
# matplotlib.use('Agg')

import torch

settings = {}

settings['start_iter'] = 1
settings['max_iter'] = 10

settings['C'] = ['back_pack', 'calculator', 'keyboard', 'monitor', 'mouse', 'mug', 'bike', 'laptop_computer', 'headphones', 'projector']
settings['Cs_dash'] = ['bike_helmet', 'bookcase', 'bottle', 'desk_chair', 'desk_lamp', 'desktop_computer', 'file_cabinet', 'letter_tray', 'mobile_phone', 'paper_notebook']
settings['Ct_dash'] = ['pen', 'phone', 'printer', 'punchers', 'ring_binder', 'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser', 'trash_can']

settings['num_C'] = len(settings['C'])
settings['num_Cs_dash'] = len(settings['Cs_dash'])
settings['num_Ct_dash'] = len(settings['Ct_dash'])
settings['num_Cs'] = settings['num_C'] + settings['num_Cs_dash']
settings['num_Ct'] = settings['num_C'] + settings['num_Ct_dash']

settings['val_after'] = 5
settings['batch_size'] = 64

settings['num_positive_samples'] = 32
settings['num_negative_samples'] = 32
settings['num_positive_images'] = settings['batch_size']
settings['num_negative_images'] = settings['batch_size']

settings['cnn_to_use'] = 'resnet50'
settings['Fs_dims'] = 256
settings['softmax_temperature'] = 1
settings['online_augmentation_90_degrees'] = True # Used for online rotations in the data loader
settings['val_aug_imgs_mean_before_softmax'] = False
settings['val_aug_imgs_mean_after_softmax'] = True
settings['separate_target_validation_set'] = True
settings['target_train_val_split'] = 0.9

xor = lambda a, b: ((a and not(b)) or (not(a) and b))
assert xor(settings['val_aug_imgs_mean_after_softmax'], settings['val_aug_imgs_mean_before_softmax'])

# For adapt
settings['running_adapt'] = True
settings['load_weights'] = True
settings['load_exp_name'] = 'usfda_office_31_DtoA'
settings['exp_name'] = 'usfda_office_31_DtoA_adapt'

settings['optimizer'] = {
	'adaptation': ['Ft'],
}

settings['lambda'] = [1, 0.1]
settings['weight_computation_method'] = 2
settings['exponential_shift'] = 0

settings['use_loss'] = {
	'adaptation': True,
}

settings['to_train'] = {
	'M': False, # -> only upto a certain conv layer. Needs to be frozen. We'll retrain the later layers.
	'Fs': False,
	'Ft': True,
	'G': False,
	'Cs': False,
	'Cn': False,
}
settings['lr'] = 1e-4 # 1e-3 default

settings['lr'] = 1e-4
settings['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    settings['gpu'] = torch.cuda.current_device()
    torch.cuda.set_device(settings['gpu'])
else:
    settings['gpu'] = None  # Set to None when using CPU

# Print device information for verification
print(f"Using device: {settings['device']}")

settings['weights_path'] = os.path.join(server_root_path, 'Office-31')
settings['summaries_path'] = os.path.join(server_root_path, 'Office-31')

# Loading Pretrained Weights
if settings['load_weights']:
    best_weights = natsorted(glob(os.path.join(settings['weights_path'], settings['load_exp_name'], 'weights', '*.pth')))[-1]
    settings['load_weights_path'] = best_weights

settings['dataset_exp_name'] = 'usfda_office_31_DtoA'

# Dataset and Paths
settings['dataset_path'] = os.path.join(server_root_path, 'Office-31', settings['dataset_exp_name'], 'index_lists')
settings['negative_data_path'] = os.path.join(server_root_path, 'Office-31', settings['dataset_exp_name'], 'negative_images')
settings['negative_mask_path'] = os.path.join(server_root_path, 'Office-31', settings['dataset_exp_name'], 'negative_masks')

# Print settings to verify
print(settings)

Using device: cuda
{'start_iter': 1, 'max_iter': 10, 'C': ['back_pack', 'calculator', 'keyboard', 'monitor', 'mouse', 'mug', 'bike', 'laptop_computer', 'headphones', 'projector'], 'Cs_dash': ['bike_helmet', 'bookcase', 'bottle', 'desk_chair', 'desk_lamp', 'desktop_computer', 'file_cabinet', 'letter_tray', 'mobile_phone', 'paper_notebook'], 'Ct_dash': ['pen', 'phone', 'printer', 'punchers', 'ring_binder', 'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser', 'trash_can'], 'num_C': 10, 'num_Cs_dash': 10, 'num_Ct_dash': 11, 'num_Cs': 20, 'num_Ct': 21, 'val_after': 5, 'batch_size': 64, 'num_positive_samples': 32, 'num_negative_samples': 32, 'num_positive_images': 64, 'num_negative_images': 64, 'cnn_to_use': 'resnet50', 'Fs_dims': 256, 'softmax_temperature': 1, 'online_augmentation_90_degrees': True, 'val_aug_imgs_mean_before_softmax': False, 'val_aug_imgs_mean_after_softmax': True, 'separate_target_validation_set': True, 'target_train_val_split': 0.9, 'running_adapt': True, 'load_we

In [5]:
!pip install tensorboardX

Collecting tensorboardX
  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl.metadata (5.8 kB)
Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/101.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.7/101.7 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensorboardX
Successfully installed tensorboardX-2.6.2.2


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import sys
import os
from torchvision import transforms, utils
import time
import cv2
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from torch.distributions.multivariate_normal import MultivariateNormal
from skimage import io
import time
from tqdm import tqdm
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
import pdb
from sklearn.preprocessing import StandardScaler
from scipy.stats import norm
from scipy.interpolate import make_interp_spline, BSpline
from glob import glob

class TrainerG():

    def __init__(self, network, optimizer, exp_name, index_lists, settings):

        # Set the network and optimizer
        self.network = network
        self.to_train = settings['to_train']

        # Optimizers to use
        self.optimizer = optimizer
        self.which_optimizer = list(sorted(self.optimizer.keys()))
        print('\noptimizers: ' + str(self.which_optimizer) + '\n')

        # Save the settings
        self.settings = settings

        # Initialize the val and train writers
        self.val_writer = SummaryWriter(os.path.join(server_root_path, 'Office-31', exp_name, 'summaries', 'logdir_val'))
        self.train_writer = SummaryWriter(os.path.join(server_root_path, 'Office-31', exp_name, 'summaries', 'logdir_train'))

        # Extract commonly used settings
        self.batch_size = settings['batch_size']
        self.current_iteration = settings['start_iter']

        # Get the index lists
        [_, _, index_list_path_train_target, index_list_path_val_target, _, index_list_path_aug_target] = index_lists
        self.index_list_train_target = np.load(index_list_path_train_target)
        self.index_list_val_target = np.load(index_list_path_val_target)
        self.target_augmentation_dict = np.load(index_list_path_aug_target, allow_pickle=True).item() # dictionary is like {filename:[list, of, augmented, files, names]}

        temp_aug_dict = {}
        for k in self.target_augmentation_dict:
            temp_aug_dict[os.path.join(server_root_path, k)] = [os.path.join(server_root_path, fn) for fn in self.target_augmentation_dict[k]]
        self.target_augmentation_dict = temp_aug_dict

        # Ensure augmented images from target validation set are removed
        self.index_list_val_target = [s for s in self.index_list_val_target if s.split('/')[-1].split('_')[0] == 'category']

        # Get number of classes
        self.num_C = settings['num_C']
        self.num_Cs_dash = settings['num_Cs_dash']
        self.num_Ct_dash = settings['num_Ct_dash']
        self.num_Cs = settings['num_Cs']
        self.num_Ct = settings['num_Ct']

        # Initialize data loaders
        self.get_all_dataloaders()


    def get_all_dataloaders(self):

        dataset_train = TemplateDataset(self.index_list_train_target, transform=transforms.Compose([transforms.ToTensor()]), random_choice=False)
        self.loader_train = DataLoader(dataset_train, batch_size=self.batch_size, shuffle=True, num_workers=2)

        dataset_target_val = TemplateDataset(self.index_list_val_target, transform=transforms.Compose([transforms.ToTensor()]), random_choice=False)
        self.loader_target_val = DataLoader(dataset_target_val, batch_size=self.batch_size, shuffle=True, num_workers=2)


    def get_histogram(self, features, bin_size=0.01, normalize=True):

        # print(features)
        # print(len(features))

        F = [int(float(f)/bin_size) for f in features]
        D = {}
        for f in F:
            if f in D: D[f] += 1
            else: D[f] = 1
        x = np.array([float(k*bin_size) for k in sorted(D.keys())])
        y = np.array([float(D[f]) for f in sorted(D.keys())])

        if normalize:
            y = y / np.sum(y)

        return x, y


    def get_prefix(self):

        return ''


    def weight_computation_step(self, concat_softmax):

        # Only exp max conf
        if self.settings['weight_computation_method'] == 1:
            num_Cs = self.num_C + self.num_Cs_dash
            W, _ = torch.max(concat_softmax[:, :num_Cs], dim=-1)
            W = torch.exp(self.settings['exponential_shift'] + W)
            return W.squeeze(), (1-W).squeeze()

        # Exp max conf normalized over batch
        elif self.settings['weight_computation_method'] == 2:
            num_Cs = self.num_C + self.num_Cs_dash
            W, _ = torch.max(concat_softmax[:, :num_Cs], dim=-1)
            W1 = torch.exp(self.settings['exponential_shift'] + W)
            W1 = ( W1 - torch.min(W1) ) / ( torch.max(W1) - torch.min(W1) )
            W2 = torch.exp(self.settings['exponential_shift'] + 1-W)
            W2 = ( W2 - torch.min(W2) ) / ( torch.max(W2) - torch.min(W2) )
            return W1.squeeze(), W2.squeeze()

        # Sum conf in Cs / Cn
        elif self.settings['weight_computation_method'] == 3:
            num_Cs = self.num_C + self.num_Cs_dash
            W = torch.sum(concat_softmax[:, :num_Cs], dim=-1)
            return W.squeeze(), (1-W).squeeze()

        # Sum conf in Cs / Cn normalized
        elif self.settings['weight_computation_method'] == 4:
            num_Cs = self.num_C + self.num_Cs_dash
            W = torch.sum(concat_softmax[:, :num_Cs], dim=-1)
            W1 = W / torch.max(W)
            W2 = (1-W) / torch.max(1-W)
            return W1.squeeze(), W2.squeeze()

        # Sum conf in Cs / Cn exponential normalized
        elif self.settings['weight_computation_method'] == 5:
            num_Cs = self.num_C + self.num_Cs_dash
            W = torch.sum(concat_softmax[:, :num_Cs], dim=-1)
            W1 = torch.exp(self.settings['exponential_shift'] + W)
            W1 = W / torch.max(W)
            W2 = torch.exp(self.settings['exponential_shift'] + 1-W)
            W2 = (1-W) / torch.max(1-W)
            return W1.squeeze(), W2.squeeze()


    def get_loss(self, which_loss):

        if which_loss == 'adaptation':

            num_Cs = self.num_C + self.num_Cs_dash

            concat_outputs = torch.cat([self.features['Cs'], self.features['Cn']], dim=-1)
            y_cap = F.softmax(concat_outputs/self.settings['softmax_temperature'], dim=-1)

            w_concat_outputs = torch.cat([self.features['w_Cs'], self.features['w_Cn']], dim=-1)
            w_y_cap = F.softmax(w_concat_outputs/self.settings['softmax_temperature'], dim=-1)
            W1, W2 = self.weight_computation_step(w_y_cap)

            # Detach W from the graph
            W1 = torch.from_numpy(W1.cpu().detach().numpy()).to(device)

            # Soft binary entropy way of pushing samples to the corresponding regions
            y_cap_s = torch.sum(y_cap[:, :num_Cs], dim=-1)
            y_cap_n = 1 - y_cap_s

            Ld_1 = W1 * (-torch.log(y_cap_s)) + W2 * (-torch.log(y_cap_n))

            # Soft categorical entropy way of pushing samples to the corresponding regions
            y_tilde_s = F.softmax(self.features['Cs']/self.settings['softmax_temperature'], dim=-1)
            y_tilde_n = F.softmax(self.features['Cn']/self.settings['softmax_temperature'], dim=-1)

            H_s = - torch.sum(y_tilde_s * torch.log(y_tilde_s), dim=-1)
            H_n = - torch.sum(y_tilde_n * torch.log(y_tilde_n), dim=-1)

            Ld_2 = W1 * H_s + W2 * H_n

            l1, l2 = self.settings['lambda']

            loss_over_batch = Ld_1 * l1 + Ld_2 * l2

            loss = torch.mean( loss_over_batch , dim=0 )

        else:
            raise NotImplementedError('Not implemented loss function ' + str(which_loss))

        self.summary_dict['loss/' + str(which_loss)] = loss.data.cpu().numpy()
        return loss


    def loss(self):

        # ==================================
        # ====== Accuracy over images ======
        # ==================================

        # Target Accuracy - all images
        concat_outputs = torch.cat([self.features['Cs'], self.features['Cn']], dim=-1)
        concat_softmax = F.softmax(concat_outputs/self.settings['softmax_temperature'], dim=-1)

        pred = torch.argmax(concat_softmax, dim=-1)
        pred[pred >= (self.num_C + self.num_Cs_dash)] = (self.num_C + self.num_Cs_dash)

        target_acc = (pred.float() == self.gt.float()).float().mean()
        self.summary_dict['acc/target_acc'] = target_acc

        if self.phase == 'train':

            # ====== BACKPROP LOSSES ======
            enough_iters = (self.current_iteration >= self.settings['val_after'])
            l = len(self.which_optimizer)
            current_loss = self.which_optimizer[self.current_iteration%l]

            if self.settings['use_loss'][current_loss] and self.backward:

                print('\nApplying loss ' + str(self.which_optimizer[self.current_iteration%l]))

                self.optimizer[self.which_optimizer[self.current_iteration%l]].zero_grad()
                loss = self.get_loss(which_loss=self.which_optimizer[self.current_iteration%l])
                loss.backward()
                self.optimizer[self.which_optimizer[self.current_iteration%l]].step()

        self.current_iteration += 1


    def forward(self):

        self.gt = Variable(torch.LongTensor(self.data['label'])).to(device).float()
        self.img_target = Variable(self.data['img'][:, :3, :, :]).to(device).float()
        self.gt[self.gt >= self.num_C] = (self.num_C + self.num_Cs_dash) # Club all the target private classes into an unknown class

        self.features = {}

        # Target data
        self.features['M'] = self.network.M(self.img_target)
        self.features['Ft'] = self.network.Ft(self.features['M'])
        self.features['Cs'] = self.network.Cs(self.features['Ft'])
        self.features['Cn'] = self.network.Cn(self.features['Ft'])

        # Passing target data through source cfier for getting the weight
        with torch.no_grad():
            self.features['w_Fs'] = self.network.Fs(self.features['M'])
            self.features['w_Cs'] = self.network.Cs(self.features['w_Fs'])
            self.features['w_Cn'] = self.network.Cn(self.features['w_Fs'])


    def train(self):

        self.phase = 'train'

        self.summary_dict = {}

        try:
            self.data = next(self.dataloader_train)[1]
            if self.data['img'].shape[0] < self.settings['batch_size']:
                self.dataloader_train = enumerate(self.loader_train)
                self.data = next(self.dataloader_train)[1]
        except:
            self.dataloader_train = enumerate(self.loader_train)
            self.data = next(self.dataloader_train)[1]

        self.forward()
        self.loss()

        return self.summary_dict['acc/target_acc']


    def log_errors(self, phase, iteration=None):

        print('log errors phase: ' + str(phase) + '\n')
        print(self.summary_dict.keys())

        for x in self.summary_dict.keys():
            print(x + ' : ' + str(float(self.summary_dict[x])))

            if phase == 'val':
                self.val_writer.add_scalar(self.get_prefix() + x, self.summary_dict[x], self.current_iteration)
            elif phase == 'train':
                self.train_writer.add_scalar(self.get_prefix() + x, self.summary_dict[x], self.current_iteration)


    def set_mode_val(self):

        self.network.eval()
        self.backward = False
        for p in self.network.parameters():
            p.requires_grad = False
            p.volatile = True


    def set_mode_train(self):

        self.network.train()
        self.backward = True
        for p in self.network.parameters():
            p.requires_grad = True
            p.volatile = False

        for comp in self.settings['to_train']:
            if self.settings['to_train'][comp] == False:
                self.network.components[comp].eval()
                for p in self.network.components[comp].parameters():
                    p.requires_grad = False
                    p.volatile = True


    def val_over_val_set(self):

        with torch.no_grad():

            enough_iters = (self.current_iteration >= self.settings['val_after'])

            self.summary_dict = {}

            # --------------
            # Target Dataset
            # --------------

            print('\nValidating on target validation data')

            dataset_target_val = TemplateDataset(self.index_list_val_target, transform=transforms.Compose([transforms.ToTensor()]), random_choice=False)
            dataloader_target = DataLoader(dataset_target_val, batch_size=self.batch_size, shuffle=True, num_workers=2)

            num_C = self.num_C
            num_Cs_dash = self.num_Cs_dash
            num_Ct_dash = self.num_Ct_dash
            num_Cs = self.num_Cs
            num_Ct = self.num_Ct

            with torch.no_grad():

                # Running calculations of accuracy

                a_private_acc = 0
                private_count = 0

                classes = list(range(num_C))
                classes.append(num_C+num_Cs_dash)

                a_avg_count = {c:0 for c in classes}
                a_avg_acc = {c:0 for c in classes}

                idx = -1

                for data in tqdm(dataloader_target):
                    idx += 1
                    x = Variable(data['img'][:, :3, :, :]).to(self.settings['device']).float()
                    labels_target = Variable(data['label']).to(self.settings['device'])
                    labels_target[labels_target>=num_C] = (self.num_C + self.num_Cs_dash) # The index corresponding to the Cn logit
                    fnames = data['filename']

                    M = self.network.components['M'](x)
                    Ft = self.network.components['Ft'](M)
                    G = self.network.components['G'](Ft)
                    Cs = self.network.components['Cs'](Ft)
                    Cn = self.network.components['Cn'](Ft)

                    concat_outputs = torch.cat([Cs, Cn], dim=-1)
                    concat_softmax = F.softmax(concat_outputs/self.settings['softmax_temperature'], dim=-1)

                    max_act, pred = torch.max(concat_softmax, dim=-1)

                    pred[pred>=(num_Cs)] = (self.num_C + self.num_Cs_dash) # Club all the negative classes into one

                    private_count += pred[labels_target>=num_C].shape[0]

                    for c in classes: # for all the shared and the unknown labels
                        a_avg_count[c] += pred[labels_target==c].shape[0]

                    a_pred = []
                    for i, original_fname in enumerate(fnames):

                        aug_imgs = []
                        for fname in self.target_augmentation_dict[original_fname]:
                            image = io.imread(fname)
                            res = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
                            image = res(image)
                            aug_imgs.append(image.view((1, 3, 224, 224)))
                        aug_imgs = torch.cat(aug_imgs, dim=0).to(device)

                        a_M = self.network.components['M'](aug_imgs)
                        a_Ft = self.network.components['Ft'](a_M)
                        a_Cs = self.network.components['Cs'](a_Ft)
                        a_Cn = self.network.components['Cn'](a_Ft)

                        a_concat_outputs = torch.cat([a_Cs, a_Cn], dim=-1)

                        if self.settings['val_aug_imgs_mean_before_softmax']:
                            a_concat_outputs = torch.mean(a_concat_outputs, dim=0)
                        a_concat_softmax = F.softmax(a_concat_outputs/self.settings['softmax_temperature'], dim=-1)

                        if self.settings['val_aug_imgs_mean_after_softmax']:
                            a_concat_softmax = torch.mean(a_concat_softmax, dim=0)

                        a_pred_temp = torch.LongTensor([int(torch.argmax(a_concat_softmax, dim=-1))]).to(device)
                        if a_pred_temp >= num_Cs:
                            a_pred_temp = torch.LongTensor([self.num_C + self.num_Cs_dash]).to(device)
                        a_pred.append(a_pred_temp)

                    a_pred = torch.cat(a_pred, dim=0).squeeze()

                    a_private_acc += (a_pred[labels_target>=num_C] == labels_target[labels_target>=num_C]).float().sum()

                    for c in classes:
                        a_avg_acc[c] += (a_pred[labels_target==c] == labels_target[labels_target==c]).float().sum()

                self.summary_dict['acc/target_a_private'] = float(a_private_acc) / float(private_count)

                # average accuracy
                a_avg = 0
                num_classes = num_C + 1
                classes = list(range(num_C))
                classes.append(num_C+num_Cs_dash)
                for c in classes:
                    if a_avg_count[c] == 0:
                        a_avg += 0
                    else:
                        a_avg += (float(a_avg_acc[c]) / float(a_avg_count[c]))
                a_avg /= float(num_classes)
                self.summary_dict['acc/target_a_avg'] = a_avg

            return self.summary_dict['acc/target_a_avg']

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import sys
import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from torch.utils.data import DataLoader
from time import time
from tqdm import tqdm
import subprocess
import warnings
from torchvision import transforms
from tensorboardX import SummaryWriter
import shutil

warnings.simplefilter("ignore", UserWarning)

# ======================= SANITY CHECK ======================= #

assert settings['running_adapt'], 'ERROR!! Config not set to run adapt trainer!!'
print('######## SANITY CHECK ########')
for key in sorted(settings.keys()):
	print('{}: {}'.format(key, settings[key]))

# ip = raw_input('continue? (y/n): ')
# if ip.lower() == 'y' or ip.lower() == 'yes':
# 	pass
# else:
# 	print('Decided not to execute!')
# 	exit()

# ==================== END OF SANITY CHECK ==================== #

max_val_acc = -10000
itt_delete = []

def main():

    print('\n Setting up data sources ...')

    # ====== DELETE PAST RUNS ======
    #torch.cuda.set_device(settings['gpu'])

    exp_name = settings['exp_name']

    # subprocess.call(["rm", "-rf", os.path.join(settings['weights_path'],exp_name)])
    # subprocess.call(["mkdir", os.path.join(settings['weights_path'],exp_name)])
    # subprocess.call(["rm", "-rf", os.path.join(settings['summaries_path'],exp_name)])
    # subprocess.call(["mkdir", os.path.join(settings['summaries_path'],exp_name)])
    # subprocess.call(["mkdir", os.path.join(settings['summaries_path'],exp_name)+"/logdir_train"])
    # subprocess.call(["mkdir", os.path.join(settings['summaries_path'],exp_name)+"/logdir_val"])

    # with open(os.path.join(settings['summaries_path'], settings['exp_name'], 'txt'), 'w') as history_file:
    # 	print('saving in ' + os.path.join(settings['summaries_path'], settings['exp_name'], 'txt'))
    # 	history_file.write('\n===== x ===== x =====\n')
    # 	for key in sorted(settings.keys()):
    # 		history_file.write('{}: {}\n'.format(key, settings[key]))

    # Define the paths for weights and summaries
    weights_path = settings['weights_path']
    summaries_path = settings['summaries_path']

    # Remove the existing directory for experiment's weights and summaries if they exist
    weights_dir = os.path.join(weights_path, exp_name)
    if os.path.exists(weights_dir):
        shutil.rmtree(weights_dir)  # Delete the directory and its contents
    os.makedirs(weights_dir)  # Create a new empty directory for weights

    summaries_dir = os.path.join(summaries_path, exp_name)
    if os.path.exists(summaries_dir):
        shutil.rmtree(summaries_dir)  # Delete the directory and its contents
    os.makedirs(summaries_dir)  # Create a new empty directory for summaries

    # Create subdirectories for training and validation logs
    train_log_dir = os.path.join(summaries_dir, "logdir_train")
    val_log_dir = os.path.join(summaries_dir, "logdir_val")
    os.makedirs(train_log_dir)  # Create training log directory
    os.makedirs(val_log_dir)  # Create validation log directory

    with open(os.path.join(settings['summaries_path'], exp_name, 'txt'), 'w') as history_file:
      print('saving in ' + os.path.join(settings['summaries_path'], exp_name, 'txt'))
      history_file.write('\n===== x ===== x =====\n')
      for key in sorted(settings.keys()):
        history_file.write('{}: {}\n'.format(key, settings[key]))

    # ====== DEFINE DATA SOURCES ======
    index_list_path_train_target = os.path.join(settings['dataset_path'], 'target_images_index_list_train.npy')
    index_list_path_val_target = os.path.join(settings['dataset_path'], 'target_images_index_list_train.npy')
    index_list_path_aug_target = os.path.join(settings['dataset_path'], 'target_images_aug_dict_train.npy')
    index_lists = [None, None, index_list_path_train_target, index_list_path_val_target, None, index_list_path_aug_target]

    # ====== CREATE NETWORK ======
    print('\n Building network ...')
    network = modnet(settings['num_C'], settings['num_Cs_dash'], settings['num_Ct_dash'], cnn=settings['cnn_to_use']).to(device)

    # Load weights
    if settings['load_weights']:
      dict_to_load = torch.load(settings['load_weights_path'])
      for component in dict_to_load:

        network.components[component].load_state_dict(dict_to_load[component])

    # Initialize weights from source networks if we are loading from supervised experiment
    if settings['load_exp_name'].split('_')[-1] != 'adapt':
      # Initialize Ft from Fs
      network.components['Ft'].load_state_dict(network.components['Fs'].state_dict())

    # ====== DEFINE OPTIMIZERS ======
    print('\n Setting up optimizers ...')
    optimizer = {}

    for key in settings['use_loss']:
      if settings['use_loss'][key]:
        to_train = []
        for comp in settings['optimizer'][key]:
          if settings['to_train'][comp]:
            to_train.append({'params': network.components[comp].parameters(), 'lr':settings['lr']})
        optimizer[key] = optim.Adam(params = to_train)

    # ====== CALL TRAINING AND VALIDATION PROCESS ======

    def trainval(network, optimizer, exp_name, index_lists, settings):

      global least_val_loss
      global itt_delete

      train_iter = settings['start_iter']

      trainer_G = TrainerG(network, optimizer, exp_name, index_lists, settings)

      train_acc_list = []

      while True:

        print ("\n----------- train_iter " + str(train_iter) + ' -----------\n')
        trainer_G.set_mode_train()
        acc_gen = trainer_G.train()
        trainer_G.log_errors('train')
        train_acc_list.append(acc_gen.data.cpu().numpy())

        if train_iter%settings['val_after'] == 0:

          print('validating')

          trainer_G.set_mode_val()
          min_val_flag=test(trainer_G)

          print('min_val_flag', min_val_flag)

          if(min_val_flag):
            print ("Saving - iteration", train_iter)
            dict_to_save = {component:network.components[component].cpu().state_dict() for component in network.components}
            # dict_to_save = {
            # 	'vgg': network.vgg.cpu().state_dict(),
            # 	'feature_ext': network.feature_ext.cpu().state_dict(),
            # 	'classifier_known': network.classifier_known.cpu().state_dict(),
            # 	'dec_feature_ext': network.dec_feature_ext.cpu().state_dict(),
            # 	'classifier_unknown': network.classifier_unknown.cpu().state_dict(),
            # 	'classifier_open_set': network.classifier_open_set.cpu().state_dict()
            # }
            torch.save(dict_to_save, os.path.join(os.path.join(settings['weights_path'],exp_name)+'/', 'best_' + str(train_iter) + '.pth'))
            network.to(device)
            itt_delete.append(train_iter)

            # Limit the number of saved checkpoints (keeping only the last 100)
            if len(itt_delete) > 100:
                for k in itt_delete[:-100]:
                    checkpoint_path = os.path.join(settings['weights_path'], exp_name, 'best_' + str(k) + '.pth')
                    if os.path.exists(checkpoint_path):  # Check if the file exists before attempting to delete
                        os.remove(checkpoint_path)  # Use os.remove to delete files in Colab

                # Keep only the last 100 iterations
                itt_delete = itt_delete[-100:]

          if train_iter > settings['max_iter']:
            break

        train_iter += 1

      print("Train Acc: ", train_acc_list[settings['max_iter']])


    def val_on_target_set(trainer_G_val_target, iteration):

      global least_val_loss
      global itt_delete

      train_iter = 1

      trainer_G = trainer_G_val_target

      trainer_G.set_mode_val()
      min_val_flag=test(trainer_G, target=True, iteration=iteration)


    def test(trainer_G, target=False, iteration=None):
      global max_val_acc
      val_record=trainer_G.val_over_val_set()

      print(val_record, max_val_acc)

      print('source set validation')
      print(val_record)
      val_acc = val_record
      # val_acc=val_record['source_classifier/accuracy_source_classifier_known_classes']
      max_val_acc=max(val_acc,max_val_acc)
      trainer_G.log_errors('val')

      if(max_val_acc==val_acc):
        return True
      else:
        return False

    trainval(network, optimizer, exp_name, index_lists, settings)


if __name__ == '__main__':
	main()

######## SANITY CHECK ########
C: ['back_pack', 'calculator', 'keyboard', 'monitor', 'mouse', 'mug', 'bike', 'laptop_computer', 'headphones', 'projector']
Cs_dash: ['bike_helmet', 'bookcase', 'bottle', 'desk_chair', 'desk_lamp', 'desktop_computer', 'file_cabinet', 'letter_tray', 'mobile_phone', 'paper_notebook']
Ct_dash: ['pen', 'phone', 'printer', 'punchers', 'ring_binder', 'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser', 'trash_can']
Fs_dims: 256
batch_size: 64
cnn_to_use: resnet50
dataset_exp_name: usfda_office_31_DtoA
dataset_path: /content/drive/MyDrive/Office-31/usfda_office_31_DtoA/index_lists
device: cuda
exp_name: usfda_office_31_DtoA_adapt
exponential_shift: 0
gpu: 0
lambda: [1, 0.1]
load_exp_name: usfda_office_31_DtoA
load_weights: True
load_weights_path: /content/drive/MyDrive/Office-31/usfda_office_31_DtoA/weights/best_15.pth
lr: 0.0001
max_iter: 10
negative_data_path: /content/drive/MyDrive/Office-31/usfda_office_31_DtoA/negative_images
negative_mask_path: /co

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 148MB/s]
  dict_to_load = torch.load(settings['load_weights_path'])



 Setting up optimizers ...

optimizers: ['adaptation']


----------- train_iter 1 -----------


Applying loss adaptation
log errors phase: train

dict_keys(['acc/target_acc', 'loss/adaptation'])
acc/target_acc : 0.5625
loss/adaptation : 1.1391979455947876

----------- train_iter 2 -----------


Applying loss adaptation
log errors phase: train

dict_keys(['acc/target_acc', 'loss/adaptation'])
acc/target_acc : 0.5
loss/adaptation : 1.031150460243225

----------- train_iter 3 -----------


Applying loss adaptation
log errors phase: train

dict_keys(['acc/target_acc', 'loss/adaptation'])
acc/target_acc : 0.625
loss/adaptation : 0.9886326193809509

----------- train_iter 4 -----------


Applying loss adaptation
log errors phase: train

dict_keys(['acc/target_acc', 'loss/adaptation'])
acc/target_acc : 0.625
loss/adaptation : 0.9897570610046387

----------- train_iter 5 -----------


Applying loss adaptation
log errors phase: train

dict_keys(['acc/target_acc', 'loss/adaptation'])
acc/target

  3%|▎         | 1/31 [03:41<1:50:55, 221.86s/it]