In [1]:
import os
os.chdir('../NMCE')
import sys
sys.path.append('./')
sys.path.append('/cis/home/tding/mcr2/mcr2_cluster/')

import wandb
import argparse
from tqdm import tqdm

import torch
import numpy as np
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader

from architectures.models_se import SinkhornDistance
from data.datasets import load_dataset
from loss import MaximalCodingRateReduction
from func import chunk_avg, cluster_acc

import utils

import random, string

from metrics_cluster import rect_pi_metrics, compute_numerical_rank, spectral_clustering_metrics, feature_detection,     sparsity, numerical_rank_from_singular_values
import plot

from utils_my import *
from augmentloader import AugmentLoader
import torchvision.transforms as transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#%%
parser = argparse.ArgumentParser(description='Unsupervised Learning')
parser.add_argument('--arch', type=str, default='MLP3D_100',
                    help='architecture for deep neural network (default: resnet18cifar)')
parser.add_argument('--z_dim', type=int, default=3,
                    help='dimension of subspace feature dimension (default: 64)')
parser.add_argument('--n_clusters', type=int, default=2,
                    help='number of subspace clusters to use (default: 10)')
parser.add_argument('--data', type=str, default='synthetic',
                    help='dataset used for training (default: cifar10)')
parser.add_argument('--aug_name', type=str, default='cifar_simclr_norm',
                    help='name of augmentation to use')
parser.add_argument('--epo', type=int, default=500,
                    help='number of epochs for training (default: 100)')
parser.add_argument('--load_epo', type=int, default=600,
                    help='epo to load pre-trained checkpoint from')
parser.add_argument('--train_backbone', action='store_true',
                    help='whether to also train parameters in backbone')
parser.add_argument('--validate_every', type=int, default=5,
                    help='validate clustering accuracy every this epochs and save results (default: 10)')
parser.add_argument('--bs', type=int, default=200,
                    help='input batch size for training (default: 1000)')
parser.add_argument('--w1', type=float, default=5., help='frequency of manifold one')
parser.add_argument('--A1', type=float, default=0.2, help='amplitude of manifold one')
parser.add_argument('--noise', type=float, default=0.05, help='noise for manifold data')
parser.add_argument('--n_views', type=int, default=2,
                    help='number of augmentations per sample')
parser.add_argument('--lr', type=float, default=0.0005,
                    help='learning rate (default: 0.001)')
parser.add_argument('--lrpi', type=float, default=0.0005,
                    help='learning rate (default: 0.001)')
parser.add_argument('--momo', type=float, default=0.2,
                    help='momentum (default: 0.9)')
parser.add_argument('--wd1', type=float, default=1e-4,
                    help='weight decay for all other parameters except clustering head(default: 1e-4)')
parser.add_argument('--wd2', type=float, default=0,
                    help='weight decay for clustering head (default: 5e-3)')
parser.add_argument('--eps', type=float, default=0.01,
                    help='eps squared for MCR2 objective (default: 0.1)')
parser.add_argument('--tau', type=float, default=1,
                    help='temperature for gumble softmax (default: 1)')
parser.add_argument('--z_weight', type=float, default=10,
                    help='weight for z_sim loss (default: 100)')
parser.add_argument('--doc', type=str, default='se_synthetic',
                    help='extra information to add to folder name')
parser.add_argument('--save_dir', type=str, default='./exps/',
                    help='base directory for saving PyTorch model. (default: ./exps/)')
parser.add_argument('--data_dir', type=str, default='/cis/home/tding/data/',
                    help='path to dataset folder')
parser.add_argument('--gpu_ids', default=[0,1], type=eval,
                    help='IDs of GPUs to use')
parser.add_argument('--fp16', action='store_true',
                    help='Whether or not to use 16-bit precision GPU training.')
parser.add_argument('--seed', type=int, default=3,
                    help='random seed')
parser.add_argument('--pieta', type=float, default=0.3,
                    help='pi sparsity (smaller, sparser)')
parser.add_argument('--pigam', type=float, default=0.2,
                    help='pi reg (smaller, sparser)')
parser.add_argument('--piiter', type=int, default=20,
                    help='sinkhorn iterations')
parser.add_argument('--selfsupiter', type=int, default=20,
                    help='sinkhorn iterations')
args = parser.parse_args(['--seed', '0'])

wandb.init(project="synthetic", entity="mcr2_clustering", name=f's2-A{args.A1}-w{args.w1}-noise{args.noise}')
wandb.config.update(args)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtianjiaoding[0m ([33mmcr2_clustering[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
#%%
def generate_manifolds_on_the_sphere(n_samples_each, w1=5, A1=0.0, noise=0.):
    num_classes = 2

    # class 1
    phi1 = np.linspace(0, 2 * np.pi, n_samples_each)
    theta1 = A1 * np.sin(w1 * phi1)

    x1 = np.cos(theta1) * np.cos(phi1) + noise * np.random.randn(n_samples_each)
    y1 = np.cos(theta1) * np.sin(phi1) + noise * np.random.randn(n_samples_each)
    z1 = np.sin(theta1) + noise * np.random.randn(n_samples_each)

    # class 2
    x2 = noise * np.random.randn(n_samples_each) + 0
    y2 = noise * np.random.randn(n_samples_each) + 0
    z2 = noise * np.random.randn(n_samples_each) + 1

    xs = np.concatenate((x1, x2)).reshape(-1, 1)
    ys = np.concatenate((y1, y2)).reshape(-1, 1)
    zs = np.concatenate((z1, z2)).reshape(-1, 1)

    X_train = torch.Tensor(np.concatenate((xs, ys, zs), axis=1))
    X_train = torch.nn.functional.normalize(X_train, dim=1)

    y_train = torch.zeros(num_classes * n_samples_each, dtype=int)
    for i in range(num_classes):
        class_inds = list(range(i * n_samples_each, (i + 1) * n_samples_each))
        y_train[class_inds] = i
    return X_train, y_train, num_classes


from torch.utils.data import Dataset as torchDataset
class Dataset(torchDataset):
    def __init__(self, n_samples_each, w1, A1, noise):

        self.n_samples_each = n_samples_each
        self.w1 = w1
        self.A1 = A1
        self.noise = noise
        self.data, self.targets, _ = generate_manifolds_on_the_sphere(n_samples_each, w1=w1, A1=A1, noise=noise)
        self._num_samples = len(self.targets)
    def __getitem__(self, key):
        return self.data[key], self.targets[key]

from torch import nn
import functional as F
class MLP3D_100(nn.Module):
    def __init__(self):
        super(MLP3D_100, self).__init__()
        self.lins_z = nn.Sequential(
            nn.Linear(3, 100),
            nn.ReLU(),
#             nn.BatchNorm1d(100),
            nn.Linear(100, 3),
        )
        self.lins_pi = nn.Sequential(
            nn.Linear(3, 100),
            nn.ReLU(),
#             nn.BatchNorm1d(100),
            nn.Linear(100, 3),
        )
#         self.w = nn.Linear(3, 3)
#         with torch.no_grad():
#             self.w.weight.copy_(torch.eye(3))


    def forward(self, x, sink_eta=0.1, sink_iter=5):
        z = F.normalize(self.lins_z(x))
        logits = F.normalize(self.lins_pi(x))

        self_coeff = (logits @ logits.T).unsqueeze(0)
        # print(f'self_coeff.shape: {self_coeff.shape}')
        sink_layer = SinkhornDistance(sink_eta, max_iter=sink_iter)
        Pi = sink_layer(self_coeff)[0]
        # print(f'Pi.shape{Pi.shape}')
        Pi = Pi * Pi.shape[-1]


        return z, Pi

class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


#GPU setup
device = 'cuda:'+str(args.gpu_ids[0])
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.benchmark = True

letters = string.ascii_letters
rand_str = ''.join(random.choice(letters) for i in range(10))

## get model directory
model_dir = os.path.join(args.save_dir,
               'selfsup_{}_{}_{}_{}'.format(
                    args.arch, args.data, args.doc, rand_str))

os.makedirs(model_dir, exist_ok=True)
os.makedirs(model_dir + '/cluster_imgs/', exist_ok=True)
os.makedirs(model_dir + '/pca_figures/', exist_ok=True)
os.makedirs(model_dir + '/checkpoints/', exist_ok=True)

wandb.config.update({"model_dir": model_dir, "model_file": "models_se"})

In [4]:
def print_link(model_dir):
    base_url = 'http://gia-e.tianjiaoding.com:65300/'
    print(base_url+model_dir[7:])
print_link(model_dir)

http://gia-e.tianjiaoding.com:65300/selfsup_MLP3D_100_synthetic_se_synthetic_kRKIoanQyz


In [5]:
#model
net = MLP3D_100()
net = nn.DataParallel(net,args.gpu_ids)
net = net.to(device)

#data
train_dataset = Dataset(args.bs // 2, args.w1, args.A1, args.noise)
test_dataset = train_dataset

transform = transforms.Compose([
        AddGaussianNoise(0., 0.01)
    ])

# train_dataset = load_dataset(args.data,args.aug_name,use_baseline=False,train=True,contrastive=True if args.n_views>1 else False,n_views=args.n_views,path=args.data_dir)
# test_data = 'stl10' if args.data in ['stl10unsup','stl10'] else args.data
# test_dataset = load_dataset(test_data,args.aug_name,use_baseline=True,train=True,contrastive=False,path=args.data_dir)
# train_loader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True, drop_last=True, num_workers=8)
# test_loader = DataLoader(test_dataset, batch_size=args.bs, shuffle=True, drop_last=False, num_workers=8)
train_loader = AugmentLoader(train_dataset,
                            transforms=transform,
                            sampler="random",
                            batch_size=args.bs*args.n_views,
                            num_aug=args.n_views,
                            shuffle=True,
                            is_tensor=True)


#loss
criterion = MaximalCodingRateReduction(eps=args.eps,gamma=1.0)

# wandb.config.update({"load_ckpt": load_path})

#only optimize cluster and subspace module
print(args.train_backbone)
para_list = [p for p in net.module.lins_z.parameters()]
para_list_c = [p for p in net.module.lins_pi.parameters()]

optimizer = optim.SGD(para_list, lr=args.lr, momentum=args.momo, weight_decay=args.wd1,nesterov=False)

optimizerc = optim.SGD(para_list_c, lr=args.lrpi, momentum=args.momo, weight_decay=args.wd2,nesterov=False)
scaler = GradScaler()

False


In [6]:
def update_pi_from_z(net):
    import copy
    model_dict = net.state_dict()
    save_dict = copy.deepcopy(model_dict)
    to_rename_keys = []
    for key in save_dict:
    #     if 'cluster' in key:
    #         to_del_keys.append(key)
        if 'lins_z' in key:
            to_rename_keys.append(key)

    # for key in to_del_keys:
    #     del save_dict[key]
    #     print(f'deleted key {key}')

    for key in to_rename_keys:
        print(f'renamed key {key}')
        pre, post = key.split('lins_z')
        save_dict[pre + 'lins_pi' + post] = save_dict.pop(key)

    model_dict.update(save_dict)
    log = net.load_state_dict(model_dict)
    print(log)
    return net

In [7]:
n_samples = args.bs
perm = torch.arange(n_samples * args.n_views).reshape(n_samples, args.n_views).transpose(1, 0).flatten()
net = update_pi_from_z(net)

renamed key module.lins_z.0.weight
renamed key module.lins_z.0.bias
renamed key module.lins_z.2.weight
renamed key module.lins_z.2.bias
<All keys matched successfully>


In [8]:
def plot_3d_my(model_dir, Z, y, name):
    colors = np.array(['green', 'blue', 'red'])
    save_dir = os.path.join(model_dir, 'figures', 'features')
    os.makedirs(save_dir, exist_ok=True)
    colors = np.array(['forestgreen', 'royalblue', 'brown'])
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(Z[:, 0], Z[:, 1], Z[:, 2], c=colors[y], cmap=plt.cm.Spectral, s=200.0)
    # Z, _ = F.get_n_each(Z, y, 1)
    # for c in np.unique(y):
        # ax.quiver(0.0, 0.0, 0.0, Z[c, 0], Z[c, 1], Z[c, 2], length=1.0, normalize=True, arrow_length_ratio=0.05, color='black')
    u, v = np.mgrid[0:2*np.pi:20j, 0:np.pi:10j]
    x = np.cos(u)*np.sin(v)
    y = np.sin(u)*np.sin(v)
    z = np.cos(v)
    ax.plot_wireframe(x, y, z, color="gray", alpha=0.5)
    ax.xaxis._axinfo["grid"]['color'] =  (0,0,0,0.1)
    ax.yaxis._axinfo["grid"]['color'] =  (0,0,0,0.1)
    ax.zaxis._axinfo["grid"]['color'] =  (0,0,0,0.1)
    my_ticks = [-1,-0.5, 0, 0.5, 1]
    my_tickpad = 10
    ax.set_xticks(my_ticks)
    ax.set_yticks(my_ticks)
    ax.set_zticks(my_ticks)
    ax.tick_params(axis='x', which='major', pad=my_tickpad)
    ax.tick_params(axis='y', which='major', pad=my_tickpad)
    ax.tick_params(axis='z', which='major', pad=my_tickpad)
    [tick.label.set_fontsize(24) for tick in ax.xaxis.get_major_ticks()] 
    [tick.label.set_fontsize(24) for tick in ax.yaxis.get_major_ticks()]
    [tick.label.set_fontsize(24) for tick in ax.zaxis.get_major_ticks()]
    ax.view_init(20, 15)
    plt.tight_layout()
    fig.savefig(os.path.join(save_dir, f"{name}.jpg"), dpi=200)
    plt.close()
# plot_3d_my(model_dir, z_avg.detach().cpu().numpy(), y_np, f'epoch{epoch}_batch{step}')

In [9]:
## Training
for epoch in range(args.epo):
    wandb.log({"epoch": epoch, "lr_z": optimizer.param_groups[0]['lr'], "lr_c": optimizerc.param_groups[0]['lr']})
    for step, (x, y, y_selfsup, sample_idx) in enumerate(train_loader):
        x = x[perm]
        y = y[perm][:n_samples]
        x, y = x.float().to(device), y.to(device)

        y_np = y.detach().cpu().numpy()

        if epoch == 0 and step == 0:
            plot_3d_my(model_dir, x[:n_samples].detach().cpu().numpy(), y.detach().cpu().numpy(), f'X_train')
        if epoch == args.selfsupiter+1 and step ==0:
            net = update_pi_from_z(net)
            print('copied z weights to pi weights')
            
        with autocast(enabled=args.fp16):
            z, Pi = net(x, sink_eta=args.pieta, sink_iter=args.piiter)
            print(f'z.shape: {z.shape}')
#                 print(f'Pi.shape: {Pi.shape}')
            Pi = (Pi[0] + Pi[1]) / 2

            print(f'Pi.shape: {Pi.shape}')
            z_avg = chunk_avg(z,n_chunks=args.n_views,normalize=True)

            # metrics on z
            z_list = z.chunk(args.n_views, dim=0)
            z_sim = (z_list[0] * z_list[1]).sum(1).mean()

            sim_mat = z_list[0] @ z_list[1].T
#             sim_mat = z @ z.T
            sim_mat = sim_mat.detach()
            spe, nnz = metrics(sim_mat, y_np)
            wandb.log({"z_spe": spe, "z_nnz": nnz})

            # metrics on pi
            Pi_np = Pi.detach().cpu().numpy()
            pi_spe = feature_detection(Pi_np, y_np)
            nnz_2 = sparsity(Pi_np, 1e-2)
            nnz_3 = sparsity(Pi_np, 1e-3)
            nnz_4 = sparsity(Pi_np, 1e-4)
            nnz_5 = sparsity(Pi_np, 1e-5)
            nnz_6 = sparsity(Pi_np, 1e-6)

            wandb.log({"pi_spe": pi_spe, "pi_nnz_2": nnz_2, "pi_nnz_3": nnz_3,
                       "pi_nnz_4": nnz_4, "pi_nnz_5": nnz_5, "pi_nnz_6": nnz_6})

            if epoch <= 10 or epoch % 5 == 0:
                acc_lst, nmi_lst, _, _, pred_lst = spectral_clustering_metrics(Pi_np, args.n_clusters, y_np)
                #                 pred_order = np.argsort(pred_lst[-1])
                wandb.log({"pi_acc": np.mean(acc_lst), "pi_nmi": np.mean(nmi_lst), "pi_nnz_3": nnz_3,
                           "pi_nnz_4": nnz_4, "pi_nnz_5": nnz_5, "pi_nnz_6": nnz_6})

                plot_fn = f'epoch{epoch}_batch{step}'
                plot_tit = f' epoch:{epoch} batch:{step}'
                plot.plot_heatmap(model_dir, z_list[0].detach().cpu().numpy(), y_np, args.n_clusters, plot_fn,
                                  title='ZtZ' + plot_tit)
                plot_membership(model_dir, Pi_np, y_np, plot_fn + '_1', title=f'Pi' + plot_tit, vmax=0.1)
                plot_membership(model_dir, Pi_np, y_np, plot_fn + '_2', title=f'Pi' + plot_tit, vmax=0.01)
                plot_membership(model_dir, Pi_np, y_np, plot_fn + '_3', title=f'Pi' + plot_tit, vmax=0.001)

                plot_3d_my(model_dir, z_avg.detach().cpu().numpy(), y_np, f'epoch{epoch}_batch{step}')


#                 if args.n_views>1:
#                     z_avg = chunk_avg(z,n_chunks=args.n_views,normalize=True)

            rank_lst = compute_numerical_rank(z_avg.detach().cpu().numpy(), y.detach().cpu().numpy(),
                                              args.n_clusters, tau=0.95)
            wandb.log({f'z_rank_class{c - 1}': r for c, r in enumerate(rank_lst)})

            
            loss, loss_list= criterion(z_avg,Pi,num_classes=n_samples,detach=False)
            loss_reg = args.pigam * 0.5 * Pi.norm()**2
            if epoch <= args.selfsupiter:
                loss = -loss_list[0] - args.z_weight*z_sim
            else:
                loss = loss + loss_reg
#         loss = loss - args.z_weight*z_sim + loss_reg
        

#         loss_list += [z_sim.item()]

        optimizer.zero_grad()
        optimizerc.zero_grad()
        if args.fp16:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizerc.step()
        else:
            loss.backward()
            optimizer.step()
            optimizerc.step()

        wandb.log({"loss_dR": -loss_list[0] + loss_list[1], "loss_R": loss_list[0],
                   "loss_Rc": loss_list[1], "loss_reg": loss_reg.item(),
                   "loss_all": loss.item()})

print("training complete.")
#%%

z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
Acc mean: 0.735   ||| stdev: 0.0000
Plot saved to: ./exps/selfsup_MLP3D_100_synthetic_se_synthetic_kRKIoanQyz/figures/heatmaps/heatmat_epoch0_batch0.png
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
Acc mean: 0.745   ||| stdev: 0.0000
Plot saved to: ./exps/selfsup_MLP3D_100_synthetic_se_synthetic_kRKIoanQyz/figures/heatmaps/heatmat_epoch1_batch0.png
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
Acc mean: 0.745   ||| stdev: 0.0000
Plot saved to: ./exps/selfsup_MLP3D_100_synthetic_se_synthetic_kRKIoanQyz/figures/heatmaps/heatmat_epoch2_batch0.png
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
Acc mean: 0.742   ||| stdev: 0.0046
Plot saved to: ./exps/selfsup_MLP3D_100_synthetic_se_synthetic_kRKIoanQyz/figures/heatmaps/heatmat_epoch3_batch0.png
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
Acc mean: 0.735   ||| stdev: 0.0000
Plot saved to: ./exps/selfsup_MLP3D_1

z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
Acc mean: 1.000   ||| stdev: 0.0000
Plot saved to: ./exps/selfsup_MLP3D_100_synthetic_se_synthetic_kRKIoanQyz/figures/heatmaps/heatmat_epoch75_batch0.png
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
Acc mean: 1.000   ||| stdev: 0.0000
Plot saved to: ./exps/selfsup_MLP3D_100_synthetic_se_synthetic_kRKIoanQyz/figures/heatmaps/heatmat_epoch80_batch0.png
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])

Acc mean: 1.000   ||| stdev: 0.0000
Plot saved to: ./exps/selfsup_MLP3D_100_synthetic_se_synthetic_kRKIoanQyz/figures/heatmaps/heatmat_epoch160_batch0.png
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
Acc mean: 1.000   ||| stdev: 0.0000
Plot saved to: ./exps/selfsup_MLP3D_100_synthetic_se_synthetic_kRKIoanQyz/figures/heatmaps/heatmat_epoch165_batch0.png
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
Acc mean: 1.000   ||| stdev: 0.0000
Plot saved to: ./exps/se

z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
Acc mean: 1.000   ||| stdev: 0.0000
Plot saved to: ./exps/selfsup_MLP3D_100_synthetic_se_synthetic_kRKIoanQyz/figures/heatmaps/heatmat_epoch250_batch0.png
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
Acc mean: 1.000   ||| stdev: 0.0000
Plot saved to: ./exps/selfsup_MLP3D_100_synthetic_se_synthetic_kRKIoanQyz/figures/heatmaps/heatmat_epoch255_batch0.png
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200

z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
Acc mean: 1.000   ||| stdev: 0.0000
Plot saved to: ./exps/selfsup_MLP3D_100_synthetic_se_synthetic_kRKIoanQyz/figures/heatmaps/heatmat_epoch340_batch0.png
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
Acc mean: 1.000   ||| stdev: 0.0000
Plot saved to: ./exps/selfsup_MLP3D_100_synthetic_se_synthetic_kRKIoanQyz/figures/heatmaps/heatmat_epoch345_batch0.png
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200

Acc mean: 1.000   ||| stdev: 0.0000
Plot saved to: ./exps/selfsup_MLP3D_100_synthetic_se_synthetic_kRKIoanQyz/figures/heatmaps/heatmat_epoch425_batch0.png
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
Acc mean: 1.000   ||| stdev: 0.0000
Plot saved to: ./exps/selfsup_MLP3D_100_synthetic_se_synthetic_kRKIoanQyz/figures/heatmaps/heatmat_epoch430_batch0.png
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
z.shape: torch.Size([400, 3])
Pi.shape: torch.Size([200, 200])
Acc mean: 1.000   ||| stdev: 0.0000
Plot saved to: ./exps/se