In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
import copy
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
from ctm_dataloader import create_dataloader

In [3]:
# Load the newsgroups export; the unnamed leading column is dropped (presumably a saved index).
news_data = pd.read_csv("newsgroups_data.csv")
print("Shape of dataset:", news_data.shape)

news_data = news_data.drop(columns=["Unnamed: 0"])

# Per-document post text, integer class label, and class-label name.
documents, target_labels, target_names = (
    news_data["content"],
    news_data["target"],
    news_data["target_names"],
)

news_data.head()

Shape of dataset: (11314, 4)


Unnamed: 0,content,target,target_names
0,From: lerxst@wam.umd.edu (where's my thing)\nS...,7,rec.autos
1,From: guykuo@carson.u.washington.edu (Guy Kuo)...,4,comp.sys.mac.hardware
2,From: twillis@ec.ecn.purdue.edu (Thomas E Will...,4,comp.sys.mac.hardware
3,From: jgreen@amber (Joe Green)\nSubject: Re: W...,1,comp.graphics
4,From: jcm@head-cfa.harvard.edu (Jonathan McDow...,14,sci.space


In [4]:
# Create dataloader
batch_size = 32
train_loader, vocab_size = create_dataloader(documents, batch_size)

# Peek at the first batch: one row per document (presumably one column per vocabulary word).
# `.head()` has to be the cell's last expression for Jupyter to display it.
first_batch = next(iter(train_loader))
print("Shape of data:", first_batch.shape)
data_df = pd.DataFrame(first_batch.numpy())
data_df.head()

Printing vocabulary:
['also' 'article' 'believe' 'call' 'come' 'drive' 'even' 'file' 'find'
 'first']
Shape of the bag-of-words model: (11314, 50)
Shape of data: torch.Size([32, 50])


In [5]:
# Sufficient-statistics class for the M-step

class SufficientStats():
    """
    Stores statistics about the optimised variational parameters during the
    E-step so that the CTM's global parameters can be updated in the M-step.

    `self.mu_stats` contains sum_d(lamda_d)

    `self.sigma_stats` contains sum_d(diag(nu2_d) + lamda_d lamda_d^T), where
    lamda_d lamda_d^T is the (num_topics x num_topics) outer product.

    `self.beta_stats[i, n]` contains sum_d(phi_d[n, i] * n_d[n]) where n_d is
    the vector of word counts of document d.

    `self.num_docs` contains the number of documents the statistics are built on.

    Inputs are detached before accumulation, so the statistics never carry an
    autograd graph even when they are fed `nn.Parameter`s.
    """

    def __init__(self, num_topics, vocab_size):
        self.num_docs = 0
        self.num_topics = num_topics
        self.vocab_size = vocab_size
        self.beta_stats = torch.zeros([num_topics, vocab_size])
        self.mu_stats = torch.zeros(num_topics)
        self.sigma_stats = torch.zeros([num_topics, num_topics])

    def update(self, lamda, nu2, phi, doc):
        """
        Add one document's optimised variational parameters to the statistics.

        Args:
            lamda: variational mean, one entry per topic.
            nu2: variational variances, one entry per topic.
            phi: word-topic responsibilities, indexed as phi[word, topic].
            doc: word counts; doc[n] is the count of vocabulary word n.
        """
        lamda = torch.as_tensor(lamda).detach()
        nu2 = torch.as_tensor(nu2).detach()
        phi = torch.as_tensor(phi).detach()
        counts = torch.as_tensor(doc, dtype=self.beta_stats.dtype)
        num_words = counts.shape[0]

        # update mu_stats
        self.mu_stats += lamda

        # beta_stats[i, n] += doc[n] * phi[n, i] for every word n of doc and every topic i
        weighted = counts[:, None] * phi[:num_words, :self.num_topics]
        self.beta_stats[:, :num_words] += weighted.T

        # update sigma_stats; torch.outer builds the k x k matrix lamda lamda^T
        # (matmul of two 1-D tensors would collapse to a scalar dot product)
        self.sigma_stats += torch.diag(nu2) + torch.outer(lamda, lamda)

        self.num_docs += 1

In [6]:
# Sanity check of the zeta initialisation used by CTM:
# zeta = sum_i exp(lamda_i + nusqrd_i / 2) for a random lamda and unit variances.
# Names are prefixed with `demo_` so this scratch cell does not overwrite the
# notebook-level `num_topics` / `vocab_size` used by later cells.
demo_num_topics = 10
demo_lamda = torch.randn(demo_num_topics)
demo_nusqrd = torch.ones(demo_num_topics)

torch.sum(torch.exp(torch.add(demo_lamda, demo_nusqrd, alpha = 0.5)))

tensor(19.6484)

In [7]:
# CTM model class - contains all the parameters and methods for the model

class CTM():
    """Correlated Topic Model (Blei & Lafferty) trained with variational EM.

    E-step: for every document a deep copy of the model fits the variational
    parameters (lamda, nusqrd, phi, zeta) by minimising `log_bound` with Adam,
    and the fitted parameters are accumulated into `SufficientStats`.
    M-step: the global parameters (mu, sigma, beta) are re-estimated from
    those statistics.
    """

    # Lower bound that keeps variances, zeta and phi strictly positive so their logs stay finite.
    _EPS = 1e-10

    def __init__(self, num_topics, vocab_size, num_epochs, batch_size, lr,
                 tol_em = 1e-6, tol_es = 1e-5, max_iter = 10):
        super(CTM, self).__init__()

        self.num_topics = num_topics # k - number of topics
        self.vocab_size = vocab_size # V - size of vocabulary
        self.tol_em = tol_em # relative change of the corpus bound at which EM stops
        self.tol_es = tol_es # relative change intended for the per-document E-step (currently unused)
        self.max_iter = max_iter # maximum number of EM iterations per call to train()
        self.num_epochs = num_epochs # Adam steps per document in the E-step
        self.batch_size = batch_size # batch size (stored only)
        self.lr = lr # learning rate of the E-step optimiser

        # Model parameters
        self.mu = torch.zeros(self.num_topics) # k-dimensional mean
        self.sigma = torch.eye(self.num_topics) # k x k covariance matrix
        self.sigma_inv = torch.inverse(self.sigma) # k x k inverse covariance matrix
        self.beta = torch.ones(num_topics, vocab_size) # k x V topic-word matrix
        self.theta = torch.rand(num_topics) # k-dimensional vector (not used by the methods below)

        # Variational parameters
        self.lamda = nn.Parameter(torch.zeros(self.num_topics)) # k - mean of q(eta)
        self.nusqrd = nn.Parameter(torch.ones(self.num_topics)) # k - variances of q(eta)
        self.phi = nn.Parameter(1/float(self.num_topics) * torch.ones(self.vocab_size, self.num_topics)) # V x k - per-word topic responsibilities
        # scalar - starts at its closed-form optimum for the initial lamda and nusqrd
        self.zeta = nn.Parameter(self.max_zeta().detach())

    def train(self, corpus):
        """Alternate E- and M-steps until the relative change of the corpus bound is below `tol_em`.

        Args:
            corpus: collection of word-count vectors (one per document); iterated and passed to len().
        """
        for i in range(self.max_iter):
            print(f'----- Training epoch {i+1} of {self.max_iter} -----')

            # NOTE(review): the bound is evaluated at self's variational parameters, which the
            # E-step never changes (it optimises copies), so it tracks the global parameters only.
            old_bound = self.corpus_bound(corpus)
            print(f'Initial bound at the start of train_loop: {old_bound}')

            statistics = self.e_phase(corpus)
            self.m_phase(statistics)

            new_bound = self.corpus_bound(corpus)
            change = new_bound - old_bound
            # Guard against a zero bound (e.g. empty corpus) instead of dividing by zero.
            delta = change / old_bound if old_bound != 0 else change

            print("delta", delta)
            print("non-normalized delta", change)

            if abs(delta) < self.tol_em:
                print('Training converged')
                writer.close()
                break

    def log_bound(self, doc, zeta = None, lamda = None, nusqrd = None, phi = None):
        """Negative evidence lower bound of one document (CTM, Blei & Lafferty eq. 4).

        Variational parameters left as None default to the model's own. The
        result stays differentiable w.r.t. every variational parameter, so the
        E-step can minimise it directly.

        Args:
            doc: word counts; doc[n] is the count of vocabulary word n.
            zeta, lamda, nusqrd, phi: optional overrides of the variational parameters.

        Returns:
            -ELBO as a 0-d tensor.

        Raises:
            RuntimeError: if the bound is NaN (e.g. a variational parameter left its domain).
        """
        if lamda is None:
            lamda = self.lamda
        if nusqrd is None:
            nusqrd = self.nusqrd
        if phi is None:
            phi = self.phi
        if zeta is None:
            zeta = self.zeta

        log_2pi = float(np.log(2 * np.pi))
        counts = torch.as_tensor(doc, dtype=phi.dtype)
        N = counts.sum()

        # E_q[log p(eta | mu, Sigma)]
        diff = lamda - self.mu
        bound = 0.5 * torch.logdet(self.sigma_inv) - 0.5 * self.num_topics * log_2pi
        bound = bound - 0.5 * (torch.trace(torch.diag(nusqrd) @ self.sigma_inv)
                               + diff @ self.sigma_inv @ diff)

        # Normaliser part of E_q[log p(z_n | eta)], using the zeta bound on E[log sum_i exp(eta_i)]
        expect = torch.sum(torch.exp(lamda + 0.5 * nusqrd))
        bound = bound + N * (-expect / zeta + 1 - torch.log(zeta))

        # Entropy of q(eta)
        bound = bound + 0.5 * torch.sum(torch.log(nusqrd) + log_2pi + 1)

        # Per-word terms: lamda part of E_q[log p(z_n | eta)], E_q[log p(w_n | z_n, beta)]
        # and the entropy of q(z_n). Only words that occur contribute, which also avoids 0 * log(0).
        num_words = counts.shape[0]
        present = counts > 0
        c = counts[present]
        phi_d = phi[:num_words][present]
        # Clamp so a word that received zero mass in an earlier M-step yields a large finite
        # penalty rather than -inf (whose gradient would turn Adam's state into NaN).
        tiny = torch.finfo(self.beta.dtype).tiny
        log_beta = torch.log(self.beta[:, :num_words].clamp_min(tiny)).T[present]
        bound = bound + torch.sum(c[:, None] * phi_d * (lamda + log_beta - torch.log(phi_d)))

        if torch.isnan(bound):
            raise RuntimeError(
                f"log_bound is NaN: min nusqrd={nusqrd.min().item()}, "
                f"min phi={phi.min().item()}, zeta={float(zeta)}")

        return -bound

    def corpus_bound(self, corpus):
        """Sum of `log_bound` over the documents of `corpus`, computed without building an autograd graph."""
        with torch.no_grad():
            return sum(self.log_bound(doc) for doc in corpus)

    def e_phase(self, corpus):
        """Fit variational parameters for each document and accumulate the sufficient statistics.

        Every document is optimised on a fresh deep copy of the model, so all
        documents start from the same initial variational parameters.

        Returns:
            A SufficientStats filled with the per-document optimised parameters.
        """
        print('Starting E-phase')
        statistics = SufficientStats(self.num_topics, self.vocab_size)

        for i, doc in enumerate(corpus):
            if i % 8 == 0:
                print(f'Processing document {i+1} to {i+8} of {len(corpus)} documents')

            model = copy.deepcopy(self)
            model.variational_inference(doc)

            # Use the parameters optimised for this document (not self's initial ones), detached
            # so the statistics - and hence beta/mu/sigma - never carry an autograd graph
            # (a non-leaf beta would make the next copy.deepcopy(self) fail).
            statistics.update(model.lamda.detach(), model.nusqrd.detach(), model.phi.detach(), doc)
        print(f'Beta at the end of e_phase: {self.beta.shape}')

        return statistics

    def _project_variational(self):
        """Project the variational parameters back into their domain after a gradient step.

        nusqrd and zeta are kept positive and each row of phi is renormalised to a distribution over topics.
        """
        with torch.no_grad():
            self.nusqrd.clamp_(min=self._EPS)
            self.zeta.clamp_(min=self._EPS)
            self.phi.clamp_(min=self._EPS)
            self.phi.div_(self.phi.sum(dim=1, keepdim=True))

    def variational_inference(self, doc):
        """Fit this model's variational parameters to `doc` by minimising `log_bound` with Adam.

        Runs `num_epochs` steps with learning rate `lr` and logs each loss to TensorBoard.
        """
        optimizer = torch.optim.Adam([self.zeta, self.lamda, self.nusqrd, self.phi], lr = self.lr)

        for epoch in range(self.num_epochs):
            optimizer.zero_grad()
            loss = self.log_bound(doc)
            loss.backward()
            optimizer.step()
            self._project_variational()
            writer.add_scalar("Loss/train", loss.item(), epoch)

        writer.flush()

    def max_zeta(self):
        """Closed-form optimum of zeta: sum_i exp(lamda_i + nusqrd_i / 2)."""
        return torch.sum(torch.exp(torch.add(self.lamda, self.nusqrd, alpha = 0.5)))

    def max_phi(self, corpus):
        """Closed-form phi update for the words of `corpus` (a single document's counts).

        phi[n, i] is proportional to exp(lamda[i]) * beta[i, n], normalised over topics i.
        """
        num_words = len(corpus)
        with torch.no_grad():
            tiny = torch.finfo(self.beta.dtype).tiny
            log_beta = torch.log(self.beta[:, :num_words].clamp_min(tiny)).T
            # softmax == exp(lamda + log beta) / sum over topics, computed stably
            self.phi[:num_words] = torch.softmax(self.lamda + log_beta, dim=1)

    def max_lamda(self, doc, lamda, phi):
        """Take 10 Adam steps on `lamda` that lower `log_bound(doc)` with `phi` held fixed (not called by train())."""
        optimizer = torch.optim.Adam([lamda], lr = self.lr)

        for _ in range(10):
            optimizer.zero_grad()
            loss = self.log_bound(doc, lamda = lamda, phi = phi)
            loss.backward()
            optimizer.step()

    def max_nusqrd(self, doc, nusqrd):
        """Take 10 Adam steps on `nusqrd` that lower `log_bound(doc)`, keeping it positive (not called by train())."""
        optimizer = torch.optim.Adam([nusqrd], lr = self.lr)

        for _ in range(10):
            optimizer.zero_grad()
            loss = self.log_bound(doc, nusqrd = nusqrd)
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                nusqrd.clamp_(min=self._EPS)

    def m_phase(self, sstats):
        """Re-estimate beta, mu and sigma from the E-step sufficient statistics.

        beta[i] is topic i's normalised expected word counts (topics with no mass keep their
        previous row), mu = mean(lamda_d) and sigma = mean(diag(nusqrd_d) + lamda_d lamda_d^T) - mu mu^T.
        Does nothing when the statistics cover no documents.
        """
        print('Starting M-phase')
        if sstats.num_docs == 0:
            return

        beta_stats = sstats.beta_stats.detach()
        beta_norm = beta_stats.sum(dim=1)
        has_mass = beta_norm > 0
        self.beta[has_mass] = beta_stats[has_mass] / beta_norm[has_mass].unsqueeze(1)

        self.mu = sstats.mu_stats.detach() / sstats.num_docs
        self.sigma = sstats.sigma_stats.detach() / sstats.num_docs - torch.outer(self.mu, self.mu)
        self.sigma_inv = torch.inverse(self.sigma)

# num_topics is the number of topics to find
def runModel(train_loader, num_topics, num_epochs, batch_size, lr, vocab_size=None):
    """Fit a CTM batch by batch, save its topic-word matrix and plot topic correlations.

    Args:
        train_loader: iterable of bag-of-words batches, one row per document
            (torch.Size([32, 50]) in this notebook, as produced by create_dataloader).
        num_topics: number of topics to find.
        num_epochs: Adam steps per document in the E-step.
        batch_size: forwarded to CTM.
        lr: learning rate of the E-step optimiser.
        vocab_size: vocabulary size; when None it is read from the last dimension
            of the first batch instead of from a notebook-global variable that
            other cells may overwrite.

    Returns:
        The fitted CTM (its beta is also saved to "beta.pt").

    Raises:
        ValueError: if vocab_size is None and train_loader yields no batches.
    """
    torch.set_printoptions(threshold=20000) # lets us print larger tensors to terminal for debugging

    if vocab_size is None:
        first_batch = next(iter(train_loader), None)
        if first_batch is None:
            raise ValueError("train_loader is empty; cannot infer vocab_size")
        vocab_size = first_batch.shape[-1]

    model = CTM(num_topics, vocab_size, num_epochs, batch_size, lr)

    for i, batch in enumerate(train_loader):
        print(f'--------- Training batch {i+1} ---------')
        model.train(batch)

    print(f'Final probabilies: {model.beta}')

    # save the model.beta
    torch.save(model.beta, "beta.pt")

    corrCoef = torch.corrcoef(model.beta)
    sns.heatmap(corrCoef, annot=True)
    plt.show(block=False)

    return model


In [None]:
# Hyperparameters passed to runModel
num_topics = 10   # k, number of topics to find
num_epochs = 10   # Adam steps per document in the E-step
batch_size = 32   # documents per batch, forwarded to CTM
lr = 1e-5         # learning rate of the E-step optimiser

In [None]:
# Fit the CTM on the bag-of-words batches built earlier; the CSV is already loaded
# and cleaned above, and runModel only needs `train_loader`.
runModel(train_loader = train_loader, num_topics = num_topics, num_epochs = num_epochs, batch_size = batch_size, lr = lr)