In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)


from psutil import virtual_memory
# Total memory reported by psutil, expressed in decimal gigabytes (1e9 bytes).
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

# 20 GB is the threshold this notebook uses to call a runtime "high-RAM".
if ram_gb >= 20:
  print('You are using a high-RAM runtime!')
else:
  print('To enable a high-RAM runtime, select the Runtime > "Change runtime type"')
  print('menu, and then select High-RAM in the Runtime shape dropdown. Then, ')
  print('re-execute this cell.')

In [None]:

# Colab-only setup: mount Google Drive so the project folder is reachable.
from google.colab import drive
drive.mount('/content/drive')

import os

import torch
# Switch the working directory to the project folder on Drive. The relative
# paths used later in this notebook ('./input', './output/model.pt') are
# resolved against it.
# NOTE(review): the local `corpus` and `utils` modules imported further down are
# presumably found because of this chdir -- confirm.
os.chdir('/content/drive/MyDrive/SampledSoftmaxSelf/LSTM')


In [None]:
# Global compute device used by the cells below: CUDA when available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Display the selected device as the cell output.
device

## RNN Model

In [None]:

import torch.nn as nn


class RNNModel2(nn.Module):
    """Token embedding -> dropout -> stacked LSTM -> dropout.

    The module stops at the (dropped-out) LSTM outputs; mapping them to
    vocabulary logits is left to a separate output layer.

    Args:
        ntoken: number of rows in the embedding table (vocabulary size).
        ninp: embedding dimension, which is also the LSTM input size.
        nhid: LSTM hidden-state size.
        nout: accepted for signature compatibility; not used by this class.
        nlayers: number of stacked LSTM layers.
        dropout: probability shared by the embedding/output dropout layer and
            the LSTM's inter-layer dropout.
    """

    def __init__(self, ntoken, ninp, nhid, nout, nlayers, dropout=0.5):
        super().__init__()

        self.nhid, self.nlayers = nhid, nlayers

        # Sub-modules are registered in the order: drop, encoder, rnn.
        self.drop = nn.Dropout(p=dropout)
        self.encoder = nn.Embedding(num_embeddings=ntoken, embedding_dim=ninp)
        # Sequence-first LSTM: inputs are (seq_len, batch_size, ninp).
        self.rnn = nn.LSTM(input_size=ninp, hidden_size=nhid, num_layers=nlayers, dropout=dropout)

        # Only the embedding table gets a custom initialisation.
        self.init_weights()

    def init_weights(self):
        """Re-draw the embedding weights uniformly from [-0.05, 0.05]."""
        bound = 0.05
        self.encoder.weight.data.uniform_(-bound, bound)

    def forward(self, emb, hidden):
        """Encode a batch of token ids.

        Args:
            emb: token ids (despite the name, these are indices fed to the
                embedding layer).
            hidden: the (h, c) LSTM state to start from.

        Returns:
            A pair ``(flat_output, hidden)`` where ``flat_output`` merges the
            first two dimensions of the LSTM output into one, and ``hidden``
            is the updated LSTM state.
        """
        embedded = self.drop(self.encoder(emb))
        rnn_out, hidden = self.rnn(embedded, hidden)
        rnn_out = self.drop(rnn_out)

        rows = rnn_out.size(0) * rnn_out.size(1)
        return rnn_out.view(rows, rnn_out.size(2)), hidden

    def init_hidden(self, bsz):
        """Return zero-filled (h, c) states shaped (nlayers, bsz, nhid).

        The tensors take their dtype and device from the module's first
        parameter.
        """
        reference = next(self.parameters()).data
        state_shape = (self.nlayers, bsz, self.nhid)
        return reference.new_zeros(state_shape), reference.new_zeros(state_shape)


# LSTM + Sampled Softmax

In [None]:
import argparse
import time
import math
import torch
import torch.nn as nn
import corpus
import easydict
import torch.nn.functional as F
from utils import *
import torch.optim as optim

# Experiment configuration, read as attributes (args.<key>) throughout the notebook.
args = easydict.EasyDict({
  # Path of a pickled model to load before each run; the empty string skips loading.
  "checkpoint": '', 
  # Data location; passed together with "dataset" to corpus.Corpus.
  "data": './input',
  # Embedding size (passed to RNNModel2 as both ninp and nout).
  "emsize": 512,
  # LSTM hidden size; also the width of the class embeddings in SampledSoftmax.
  "nhid": 512,
  # Number of stacked LSTM layers.
  "nlayers": 2,
  # Initial learning rate; overwritten when the optimizer is selected in the run loop.
  "lr": 20,
  # Maximum gradient norm passed to clip_grad_norm_.
  "clip": 0.35,
  # Number of training epochs per run.
  "epochs": 50,
  # Training batch size (evaluation uses its own eval_batch_size).
  "batch_size": 20,
  # Chunk length handed to get_batch.
  "bptt": 35,
  # Dropout probability passed to RNNModel2.
  "dropout": 0.5,
  # File the best model (lowest validation loss) is saved to.
  "save": './output/model.pt',
  # Optimizer name; the run loop resets this to 'Adam' before building the optimizer.
  "opt": "Adam",
  # Number of classes drawn per example by the sampler during training.
  "softmax_nsampled": 40,
  # Output-layer mode; train()/evaluate() only compare it against 'Full'.
  "method": 'RandomFeature',
  # Random-feature variant used by RandomFeatureSampler: 'RFF' or 'FAVOR'.
  "sub_method": "FAVOR",
  # Number of rows of the sampler's random projection matrix.
  "rf_D": 1024,
  # When True, embeddings are L2-normalized before being given to the sampler.
  "normalize_phi": True,
  # Not referenced by any code visible in this notebook.
  "tau": None,
  # Scale applied to the L2-normalized embeddings fed to the sampler.
  "tau_hc": 1,
  # Dataset name; passed together with "data" to corpus.Corpus.
  "dataset": 'ptb'
  
})



# Seed PyTorch's CPU and CUDA random generators once for the whole notebook.
torch.manual_seed(1111)
torch.cuda.manual_seed(1111)


# Load data
# NOTE(review): this rebinds the name `corpus` from the imported module to the
# Corpus instance; after this line the module is no longer reachable as `corpus`.
corpus = corpus.Corpus(args.data, args.dataset)


def batchify(data, bsz):
    """Fold a token stream into `bsz` parallel columns.

    Elements at the end of `data` that do not fill a complete row are dropped.

    Args:
        data: tensor of tokens (expected to be a flat, 1-D stream).
        bsz: number of columns, i.e. the batch size.

    Returns:
        A contiguous tensor with `data.size(0) // bsz` rows and `bsz` columns;
        each column is a consecutive slice of the original stream.
    """
    rows_per_column = data.size(0) // bsz
    usable = data.narrow(0, 0, rows_per_column * bsz)
    as_rows = usable.view(bsz, -1)
    return as_rows.t().contiguous()


# Validation and test share one batch size, separate from the training batch size.
eval_batch_size = 10

# Every split is folded by batchify into a (total_len // bsz, bsz) tensor.
train_data = batchify(corpus.train, args.batch_size)
val_data, test_data = (
    batchify(split, eval_batch_size) for split in (corpus.valid, corpus.test)
)


# Build the model
interval = 700  # number of batches between progress reports in train()
ntokens = len(corpus.dictionary)  # vocabulary size (10000 per the original note)



## RandomFeature Sampler

In [None]:

class RandomFeatureSampler(object):
    """Proposal distribution for sampled softmax built from random feature maps.

    A Gaussian matrix of shape (rf_D, nhid) projects both the query embeddings
    and the class embeddings. The projections are turned into feature vectors
    (`generate_phi`), and the inner products of those vectors, normalized per
    row, form the sampling distribution over classes.

    Supported `sub_method` values:
        'RFF'   -- features are sqrt(1/D) * [cos(z), sin(z)]
        'FAVOR' -- features are sqrt(1/(2D)) * [exp(z), exp(-z)]

    Args:
        rf_D: number of rows of the random projection matrix.
        nhid: embedding dimension of the inputs and the classes.
        sub_method: 'RFF' or 'FAVOR'. Any other value is accepted here but
            makes `generate_phi` and `sample` raise ValueError.
    """

    _SUPPORTED = ('RFF', 'FAVOR')

    def __init__(self, rf_D, nhid, sub_method):

        self.D = rf_D
        self.sub_method = sub_method

        # Both supported variants use the same kind of Gaussian projection,
        # placed on the notebook-level `device`.
        self.orth = None
        if sub_method in self._SUPPORTED:
            self.orth = torch.randn(rf_D, nhid, dtype=torch.float32).to(device)

    def _check_sub_method(self):
        """Raise a descriptive error instead of failing later on an unbound name."""
        if self.sub_method not in self._SUPPORTED:
            raise ValueError(
                "unsupported sub_method {!r}; expected one of {}".format(
                    self.sub_method, self._SUPPORTED))

    # for sampled full-softmax, rff, and favor
    def generate_phi(self, orth_emb):
        """Map projected embeddings to random features.

        Args:
            orth_emb: projected embeddings, one row per item.

        Returns:
            A tensor with twice as many columns as `orth_emb`.

        Raises:
            ValueError: if `sub_method` is not 'RFF' or 'FAVOR'.
        """
        self._check_sub_method()
        # The scale uses the sampler's own D (not the global config) so that it
        # always matches the projection this instance was built with.
        if self.sub_method == 'RFF':
            scale = math.sqrt(1.0 / self.D)
            return scale * torch.hstack((torch.cos(orth_emb), torch.sin(orth_emb)))

        scale = math.sqrt(1.0 / (2 * self.D))
        return scale * torch.hstack((torch.exp(orth_emb), torch.exp(-orth_emb)))

    @torch.no_grad()
    def sample(self, inputs_emb, class_emb_weight, labels, num_samples):
        """Draw `num_samples` classes per input row from the feature-based proposal.

        Args:
            inputs_emb: query embeddings, one row per example (numpy array or
                tensor; converted to float32 on the projection's device).
            class_emb_weight: class embeddings, one row per class (same
                conversion as `inputs_emb`).
            labels: 1-D integer tensor with the true class of every example.
            num_samples: number of classes drawn (with replacement) per example.

        Returns:
            (sampled_ids, true_qs, sampled_qs):
                sampled_ids -- (n_examples, num_samples) drawn class ids,
                true_qs     -- (n_examples,) proposal probability of the true class,
                sampled_qs  -- (n_examples, num_samples) proposal probabilities
                               of the drawn classes.

        Raises:
            ValueError: if `sub_method` is not 'RFF' or 'FAVOR'.
        """
        self._check_sub_method()

        projection = self.orth
        target = projection.device
        inputs_emb = torch.as_tensor(inputs_emb, dtype=torch.float32, device=target)
        class_emb_weight = torch.as_tensor(class_emb_weight, dtype=torch.float32, device=target)
        labels = labels.to(target)

        phi_h = self.generate_phi(torch.matmul(inputs_emb, projection.T))
        phi_c = self.generate_phi(torch.matmul(class_emb_weight, projection.T))

        # Un-normalized proposal; RFF products can be negative, so clamp at zero
        # before turning each row into a probability distribution.
        q_matrix = torch.matmul(phi_h, phi_c.T).clamp_(min=0)
        q_matrix = q_matrix / q_matrix.sum(dim=1, keepdim=True)

        # Read the true-class probability from the same matrix that is sampled
        # from, so it is subject to the same clamping and normalization.
        true_qs = q_matrix.gather(1, labels.unsqueeze(1)).squeeze(1)
        # A clamped-to-zero probability would become log(0) in the caller's
        # correction term; keep it at the smallest positive normal value.
        true_qs = true_qs.clamp_min(torch.finfo(true_qs.dtype).tiny)

        sampled_ids = torch.multinomial(q_matrix, num_samples, replacement=True)
        sampled_qs = torch.gather(q_matrix, 1, sampled_ids)

        return sampled_ids, true_qs, sampled_qs





## Sampled Softmax

In [None]:


class SampledSoftmax(nn.Module):
    """Output layer: sampled softmax while training, full softmax otherwise.

    Holds one embedding per class (`class_emb`) and a `RandomFeatureSampler`
    that proposes negative classes. The notebook-level `args` object is read
    in `__init__` (method, rf_D, sub_method) and in `forward`
    (normalize_phi, tau_hc).

    Args:
        ntokens: number of classes (vocabulary size).
        nsampled: number of classes drawn per example during training.
        nhid: dimension of the incoming hidden states and of the class embeddings.
    """

    def __init__(self, ntokens, nsampled, nhid):
        super(SampledSoftmax, self).__init__()

        # Parameters
        self.ntokens = ntokens
        self.nsampled = nsampled
        self.method = args.method

        self.sampler = RandomFeatureSampler(args.rf_D, nhid, args.sub_method)

        # One row per class: weight has shape (ntokens, nhid).
        self.class_emb = nn.Embedding(ntokens, nhid)

        self.init_weights()

    def init_weights(self):
        """Re-draw the class embeddings uniformly from [-0.12, 0.12]."""
        initrange = 0.12
        self.class_emb.weight.data.uniform_(-initrange, initrange)

    def forward(self, inputs_emb, labels):
        """Return sampled-softmax logits in training mode, full logits otherwise.

        Args:
            inputs_emb: (batch, nhid) hidden states.
            labels: (batch,) true class ids.

        Returns:
            Training mode: ``(logits, new_targets)`` as produced by `sampled`.
            Eval mode: the (batch, ntokens) logits produced by `full`.
        """
        if not self.training:
            return self.full(inputs_emb)

        # The sampler only needs values, never gradients. It reads this
        # module's own class embeddings (not a global instance).
        sampler_inputs = inputs_emb.detach()
        sampler_classes = self.class_emb.weight.detach()
        if args.normalize_phi:
            sampler_inputs = args.tau_hc * F.normalize(sampler_inputs, p=2, dim=1)
            sampler_classes = args.tau_hc * F.normalize(sampler_classes, p=2, dim=1)
        # With normalize_phi disabled the raw embeddings are used, so that
        # `sample_values` is always defined.

        # The sampler is handed numpy arrays, as before.
        sample_values = self.sampler.sample(
            sampler_inputs.cpu().numpy(),
            sampler_classes.cpu().numpy(),
            labels,
            self.nsampled,
        )
        return self.sampled(inputs_emb, labels, sample_values, remove_accidental_match=True)

    def sampled(self, inputs_emb, labels, sample_values, remove_accidental_match=False):
        """Build the corrected logits for the true class and the sampled classes.

        Args:
            inputs_emb: (batch, nhid) hidden states.
            labels: (batch,) true class ids.
            sample_values: ``(sample_ids, true_freq, sample_freq)`` with shapes
                (batch, nsampled), (batch,) and (batch, nsampled); the two
                frequency tensors are the proposal probabilities.
            remove_accidental_match: when True, sampled classes that equal the
                true label get a very large negative logit.

        Returns:
            ``(logits, new_targets)``: logits is (batch, 1 + nsampled) with the
            true class in column 0; new_targets is an all-zero long tensor.
        """
        batch_size = inputs_emb.size(0)
        target = inputs_emb.device
        sample_ids, true_freq, sample_freq = (t.to(target) for t in sample_values)

        # Logit of the true class: row-wise dot product with its embedding.
        true_class_emb = torch.index_select(self.class_emb.weight, 0, labels)
        true_logits = torch.sum(torch.mul(inputs_emb, true_class_emb), dim=1)

        # Logits of the sampled classes only: look up their embeddings and take
        # batched dot products instead of scoring every class and gathering.
        sampled_class_emb = self.class_emb(sample_ids)  # (batch, nsampled, nhid)
        sample_logits = torch.bmm(sampled_class_emb, inputs_emb.unsqueeze(2)).squeeze(2)

        # remove true labels from sample set
        if remove_accidental_match:
            accidental = labels.unsqueeze(1) == sample_ids
            sample_logits = sample_logits + accidental * (-1e37)

        # perform correction: subtract log(expected count under the proposal)
        true_logits = true_logits - torch.log(true_freq * self.nsampled)
        sample_logits = sample_logits - torch.log(sample_freq * self.nsampled)

        # return logits and new_labels
        logits = torch.cat((true_logits.unsqueeze(1), sample_logits), dim=1)
        new_targets = torch.zeros(batch_size, dtype=torch.long, device=target)

        return logits, new_targets

    def full(self, inputs_emb):
        """Return the (batch, ntokens) logits against every class embedding."""
        return torch.matmul(inputs_emb, self.class_emb.weight.T)





## train & val

In [None]:

def evaluate(data_source):
    """Compute the evaluation loss of the global `model` on a batchified split.

    The split is walked in chunks of `args.bptt` rows. Each chunk's criterion
    value is weighted by the chunk length, and the weighted sum is divided by
    `len(data_source)`.

    Relies on notebook globals: model, args, criterion, device,
    eval_batch_size, get_batch, repackage_hidden, and either `decoder`
    (method 'Full') or the sampled-softmax layer.

    Args:
        data_source: tensor produced by `batchify` with `eval_batch_size` columns.

    Returns:
        The loss as a Python float (0.0 if no chunk was processed).
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()

    if args.method != 'Full':
        # Score with the output layer attached to the model being evaluated.
        # After a saved model is reloaded, the global `sampled_softmax` may be
        # a different (later) set of weights than the one stored with it.
        output_layer = getattr(model, 'decoder', None)
        if not hasattr(output_layer, 'full'):
            output_layer = sampled_softmax

    total_loss = 0.0
    hidden = model.init_hidden(eval_batch_size)

    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, args.bptt):  # one chunk of rows per step
            data, targets = get_batch(data_source, i, args.bptt)
            data, targets = data.to(device), targets.to(device)

            # NOTE(review): repackage_hidden comes from utils; presumably it
            # detaches the carried state from the previous chunk -- confirm.
            hidden = repackage_hidden(hidden, device)

            # run RNN model
            output, hidden = model(data, hidden)

            # run decoder (always the full softmax at evaluation time)
            if args.method == 'Full':
                logits = decoder(output)
            else:
                logits = output_layer.full(output)

            total_loss += len(data) * criterion(logits, targets).item()

    return total_loss / len(data_source)






def train():
    """Run one training epoch over the global `train_data`.

    For every chunk of `args.bptt` rows: run the LSTM, build logits (full
    decoder or sampled softmax), back-propagate, clip the gradient norm to
    `args.clip` and step the optimizer. Only the time spent producing the
    logits is accumulated for the "ms/batch" figure.

    Relies on notebook globals: model, train_data, args, optimizer, criterion,
    device, interval, epoch, get_batch, repackage_hidden, and either `decoder`
    (method 'Full') or `sampled_softmax`.
    """
    model.train()
    total_loss = 0.0
    softmax_time_total = 0
    # Batches accumulated since the last report. The first report covers
    # batches 0..interval inclusive, so a fixed divisor of `interval` would be
    # off by one.
    batches_since_report = 0
    hidden = model.init_hidden(args.batch_size)

    for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
        data, targets = get_batch(train_data, i, args.bptt)
        data, targets = data.to(device), targets.to(device)

        # NOTE(review): repackage_hidden comes from utils; presumably it
        # detaches the carried state from the previous chunk -- confirm.
        hidden = repackage_hidden(hidden, device)

        # run RNN model
        output, hidden = model(data, hidden)

        start_time = time.time()  # time only the output layer

        if args.method == 'Full':
            logits = decoder(output)
        else:
            logits, new_targets = sampled_softmax(output, targets)

        softmax_time_total += time.time() - start_time

        optimizer.zero_grad()

        # loss: sampled softmax is scored against its own targets (true class in column 0)
        if args.method == 'Full':
            loss = criterion(logits, targets)
        else:
            loss = criterion(logits, new_targets)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        optimizer.step()

        total_loss += loss.item()
        batches_since_report += 1

        if batch % interval == 0 and batch > 0:
            cur_loss = total_loss / batches_since_report

            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.4f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // args.bptt, args.lr,
                softmax_time_total * 1000 / batches_since_report, cur_loss, math.exp(cur_loss)))
            total_loss = 0.0
            softmax_time_total = 0
            batches_since_report = 0




## FAVOR+ — tau_hc = 1, rf_D = 1024, m = 40, 10 runs

In [None]:
# Settings for this experiment; they overwrite the defaults in `args`.
args.sub_method = 'FAVOR'  # random-feature variant used by the sampler
args.tau_hc =1  # scale applied to the L2-normalized embeddings given to the sampler
args.rf_D = 1024  # number of rows of the sampler's random projection
args.softmax_nsampled = 40  # classes drawn per example (the "m" in the heading)

In [None]:
# Display the effective configuration as the cell output.
args

In [None]:
# Ten independent runs: build a fresh model, train for args.epochs epochs with
# per-epoch validation, keep the best model on disk, then score it on the test split.
for run in range(10):

    model = RNNModel2(ntokens, args.emsize, args.nhid, args.emsize, args.nlayers,  args.dropout).to(device)

    # Load checkpoint
    # NOTE(review): the decoder attached below is freshly initialised, so any
    # output layer stored inside the checkpoint is replaced -- confirm intent.
    if args.checkpoint != '':
        model = torch.load(args.checkpoint, map_location=lambda storage, loc: storage)

    print(args)

    sampled_softmax = SampledSoftmax(ntokens = ntokens, nsampled = args.softmax_nsampled, nhid = args.nhid).to(device)
    # Registering the output layer on the model makes its parameters part of
    # model.parameters() (optimizer, gradient clipping) and of the saved model.
    model.add_module("decoder", sampled_softmax)

    # Follow the notebook-level `device` so the CPU fallback keeps working.
    model.to(device)

    # Loop over epochs.
    lr = args.lr
    best_val_loss = None

    args.opt = 'Adam'

    # Each optimizer comes with its own learning rate, which overwrites args.lr.
    if args.opt == 'SGD':
        lr, args.lr = 20, 20
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    elif args.opt == 'Adam':
        lr, args.lr = 0.001, 0.001
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.99))
    elif args.opt == 'Momentum':
        lr, args.lr = 20, 20
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.8)
    elif args.opt == 'RMSprop':
        lr, args.lr = 0.001, 0.001
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, alpha=0.9)
    elif args.opt == 'Adagrad':
        lr, args.lr = 20, 20
        optimizer = torch.optim.Adagrad(model.parameters(), args.lr, weight_decay=1e-5)
    else:
        # Fail loudly rather than reuse a stale (or undefined) optimizer.
        raise ValueError("unknown optimizer: {!r}".format(args.opt))


    criterion = nn.CrossEntropyLoss()
    print(model)


    try:

        val_loss_list = []
        val_perp_list = []


        print("optimizer is:", args.opt)

        for epoch in range(1, args.epochs+1):
            epoch_start_time = time.time()
            train()
            val_loss = evaluate(val_data)
            val_loss_list.append(val_loss)
            val_perp_list.append(math.exp(val_loss))
            print('-' * 89)
            print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss)))
            print('-' * 89)
            # Save the model if the validation loss is the best we've seen so far.
            if best_val_loss is None or val_loss < best_val_loss:
                with open(args.save, 'wb') as f:
                    torch.save(model, f)
                best_val_loss = val_loss
            else:
                print("########### ?????? ############")
                # Anneal the learning rate if no improvement has been seen in the validation dataset.
                if args.opt == 'SGD' or args.opt == 'Momentum':
                    args.lr /= 4.0
                    lr = args.lr
                    for group in optimizer.param_groups:
                        group['lr'] = args.lr


    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')

    # Load the best saved model.
    # NOTE(review): this unpickles a whole model object. Recent PyTorch releases
    # default torch.load to weights_only=True, which rejects such files --
    # confirm against the installed version.
    with open(args.save, 'rb') as f:
        model_saved = torch.load(f)

    model = model_saved
    # Point the global output layer at the one stored with the best model, so
    # the test score does not mix best-epoch LSTM weights with last-epoch
    # output embeddings.
    sampled_softmax = getattr(model, 'decoder', sampled_softmax)

    # Run on test data.
    test_loss = evaluate(test_data)
    print('=' * 89)
    print('| End of training | test loss {:5.2f} | test perplexity {:8.2f}'.format( test_loss, math.exp(test_loss)))
    print('=' * 89)


    #with open('./output/FAVOR_tauhc='+str(args.tau_hc)+'_rf_D='+str(args.rf_D)+'_m='+str(args.softmax_nsampled)+'_50epochs_run'+str(run)+'.npy', 'wb') as f:
    #    np.save(f, np.array(val_loss_list))