In [None]:
from sklearn.manifold import TSNE
from model import MyModel
from build import build_dataset
from dataset import PrefixDataset1

import os
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
import argparse

In [None]:
parser = argparse.ArgumentParser()

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=0.0003, help='Learning rate')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')
    parser.add_argument('--epochs', type=int, default=100, help='train epochs')
    parser.add_argument('--milestones', type=int, nargs='+', default=[116, 233], help='Milestones')
    parser.add_argument('--gamma', type=float, default=0.1, help='Gamma')
    # parser.add_argument('--optimizer', type=str, default='sgd', help='optimizer')

    parser.add_argument('--voc_len',type=int, default=42020, help='voc number')
    parser.add_argument('--embedding_dim',type=int, default=1024, help='embedding size')
    parser.add_argument('--output_dim', type=int, default=64, help="output dim")
    parser.add_argument('--dstore_mmap',type=str, default='/data/zqh/NLP/adaptive-knn-mt/store/datastore/it_finetune')
    parser.add_argument('--dstore_size',type=int, default=3608731, help='datastore size')
    parser.add_argument('--use_cluster', type=bool, default=True, help="if use word cluster")
    parser.add_argument('--cluster_type', type=str, default='spectrum', help='cluster type')
    
    # contrastive learning
    parser.add_argument('--K', type=int, default=500, help='queue size')
    parser.add_argument('--m', type=float, default=0.999, help='momentum')
    parser.add_argument('--class_num', type=int, default=42020, help="class number")
    

    # save
    parser.add_argument('--save_path', type=str, default='/data/zqh/adaptive-knn-mt/checkpoints/koran', help='save checkpoint dir')
    # dataset
    args = parser.parse_args([])
    return args

In [None]:
args = get_args()
dataset = PrefixDataset1(args=args)

from sklearn.cluster import SpectralClustering
from sklearn.cluster import DBSCAN
import math

# How many classes (label ids) to visualise and how many samples to draw per class.
NUM_VIS_CLASSES = 2
SAMPLES_PER_CLASS = 200

choice_label = np.arange(args.voc_len)
np.random.shuffle(choice_label)

# Take the first NUM_VIS_CLASSES labels (in shuffled order) that have at least
# SAMPLES_PER_CLASS samples in the dataset.
real_choice_label = []
for i in choice_label:
    if int((dataset.label == i).sum()) >= SAMPLES_PER_CLASS:
        real_choice_label.append(i)
    if len(real_choice_label) == NUM_VIS_CLASSES:
        break

# Draw up to SAMPLES_PER_CLASS random samples per chosen class. `labels` holds
# the class's position j in real_choice_label (0, 1, ...), not the original label id.
label_parts = []
embedding_parts = []
for j, i in enumerate(real_choice_label):
    class_embedding = dataset.data[dataset.label == i]
    choice_sample = np.arange(class_embedding.shape[0])
    np.random.shuffle(choice_sample)
    number = min(SAMPLES_PER_CLASS, class_embedding.shape[0])
    # Previously `[:]`, which silently kept every sample and ignored `number`.
    choice_sample = choice_sample[:number]

    label_parts.append(np.full(choice_sample.shape, j))
    embedding_parts.append(class_embedding[choice_sample])

# Stay None when no class qualified, matching the previous behaviour.
labels = np.concatenate(label_parts) if label_parts else None
embedding = np.concatenate(embedding_parts) if embedding_parts else None


In [None]:
# 2-D t-SNE projection of the sampled raw embeddings, coloured by class index.
tsne = TSNE(n_components=2)
tsne_embedding = tsne.fit_transform(embedding)

fig, ax = plt.subplots()
ax.scatter(tsne_embedding[:, 0], tsne_embedding[:, 1], c=labels)
ax.set(title="t-SNE of raw embeddings", xlabel="t-SNE 1", ylabel="t-SNE 2")
plt.show()

In [None]:
from torch.utils.data import Dataset

class EmbeddingDataset(Dataset):
    """Dataset yielding (anchor, positive, label) triples from sampled embeddings.

    Args:
        args: experiment configuration; stored but not otherwise used here.
        data: per-sample embeddings. Defaults to the notebook global
            ``embedding`` (kept for backward compatibility with the
            zero-argument form ``EmbeddingDataset(args)``).
        targets: per-sample labels aligned with ``data``. Defaults to the
            notebook global ``labels``.
    """

    def __init__(self, args, data=None, targets=None):
        super().__init__()
        self.args = args

        # Fall back to the globals produced by the sampling cell.
        if data is None:
            data = embedding
        if targets is None:
            targets = labels
        self.data = np.array(data)
        self.labels = np.array(targets)

        # Index of the first occurrence of each label, computed once so that
        # __getitem__ is O(1) instead of scanning all labels on every call.
        unique_labels, first_idx = np.unique(self.labels, return_index=True)
        self._first_index = dict(zip(unique_labels.tolist(), first_idx.tolist()))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return (embedding, positive embedding, label) for sample ``index``."""
        anchor = self.data[index]
        label = self.labels[index]
        # NOTE(review): the positive is always the FIRST sample of the same class,
        # so every item of a class shares one identical key (and that first sample
        # is paired with itself). If a random same-class partner was intended,
        # choose it here instead.
        positive = self.data[self._first_index[label]]
        return anchor, positive, label

In [None]:
from torch.utils.data import DataLoader

mymodel = MyModel(args).cuda()

# Distinct name so the PrefixDataset1 `dataset` used by the sampling cell is not
# overwritten (re-running that cell would otherwise fail on `.label`).
embedding_dataset = EmbeddingDataset(args)

dataloader = DataLoader(
    dataset=embedding_dataset,
    batch_size=args.batch_size,
    shuffle=True,
)

optimizer = torch.optim.SGD(mymodel.parameters(), args.lr,
                            momentum=0.9, nesterov=True,
                            weight_decay=0.0004)
# self.optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)


for epoch in range(args.epochs):
    # Explicit: a later evaluation cell may have left the model in eval mode.
    mymodel.train()
    correct = 0
    data_len = 0
    for i, (x, x_key, label) in enumerate(dataloader):
        x = x.cuda()
        x_key = x_key.cuda()
        label = label.cuda().long()

        logits, loss, _, _ = mymodel(x, x_key, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        predictions = logits.argmax(dim=-1, keepdim=True)
        batch_correct = predictions.eq(label.view_as(predictions)).sum().item()
        correct += batch_correct
        data_len += x.shape[0]

        if i % 50 == 0:
            acc = batch_correct / x.shape[0]
            print(f"Train epoch: {epoch} loss: {loss.item()} acc: {acc}")

    # Previously never called, so the MultiStepLR milestones had no effect.
    # Stepped once per epoch, i.e. milestones are treated as epoch indices
    # (assumption — confirm against how milestones were chosen).
    scheduler.step()
    print(f"Train epoch: {epoch} epoch acc: {correct / max(data_len, 1)}")

In [None]:
# 2-D t-SNE projection of the trained encoder's outputs, coloured by class index.
tsne = TSNE(n_components=2)

# Inference only: eval mode + no_grad avoids building autograd graphs and uses
# inference behaviour for any dropout/batch-norm layers. Restore the prior mode after.
was_training = mymodel.training
mymodel.eval()
hidden_parts = []
label_parts = []
with torch.no_grad():
    for x, _x_key, label in dataloader:
        hidden_parts.append(mymodel.encode(x.cuda()).cpu())
        label_parts.append(label)
mymodel.train(was_training)

# Concatenate once instead of growing a tensor inside the loop.
con_embedding = torch.cat(hidden_parts).numpy()
con_label = torch.cat(label_parts).numpy()

tsne_embedding = tsne.fit_transform(con_embedding)
fig, ax = plt.subplots()
ax.scatter(tsne_embedding[:, 0], tsne_embedding[:, 1], c=con_label)
ax.set(title="t-SNE of encoded embeddings", xlabel="t-SNE 1", ylabel="t-SNE 2")
plt.show()
