In [1]:
%load_ext autoreload
%autoreload 2

import pandas as pd

from preprocessing.preprocess_text import clean_text
from preprocessing.gen_text_embeddings import generate_text_embeddings
import torch
from pytorch_datasets.text_dataset import TextDataset
from torch.utils.data import DataLoader, random_split
from models.vae_text import TextVAE
from torch.optim import Adam

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/mehta.vats/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [2]:
# GPU configuration for this notebook.
#
# The CUDA-only calls are guarded so the cell also runs on a machine without a
# GPU; the training cell already falls back to "cpu" in that case.
cuda_device = 0  # index of the GPU you want to use
max_memory_usage = 0.95  # maximum GPU memory usage, as a fraction of the device's memory

torch.backends.cudnn.benchmark = True

if torch.cuda.is_available():
    torch.cuda.set_device(cuda_device)
    # Cap this process at `max_memory_usage` of the selected device's memory.
    torch.cuda.set_per_process_memory_fraction(max_memory_usage, cuda_device)

In [3]:
# Report the peak memory PyTorch has allocated on the selected device, scaled
# down by 1024**3, then ask PyTorch to empty its CUDA cache.
peak_allocated = torch.cuda.max_memory_allocated(cuda_device)
print(peak_allocated / 2**30)
torch.cuda.empty_cache()

0.0


In [4]:
# Load and preprocess data
#
# Reads the raw CSV, drops rows whose "description" is missing, runs
# clean_text over every string description and collects the "Party" column as
# the class labels.

# low_memory=False makes pandas infer each column's dtype from the whole file
# instead of chunk by chunk.
# NOTE(review): the saved output of this cell shows a warning raised at this
# call, presumably pandas' DtypeWarning about mixed-type columns -- confirm.
instagram_data = pd.read_csv("data/instagram_data.csv", low_memory=False)
data = instagram_data.dropna(subset=["description"]).reset_index(drop=True)

print(f"Removed {len(instagram_data) - len(data)} rows due to N/A descriptions.")

# Uncomment to work on a random 10% sample of the rows:
# data = data.sample(frac=0.1)

post_descriptions = (
    data["description"]
    # Non-string values are passed through untouched; only strings are cleaned.
    .apply(lambda text: clean_text(text) if isinstance(text, str) else text)
    .tolist()
)

post_classes = data["Party"].tolist()

  instagram_data = pd.read_csv("data/instagram_data.csv")


Removed 7662 rows due to N/A descriptions.


In [5]:
# Generate Text Embeddings
# The helper returns two values: `word_index_mapping` is later handed to
# TextDataset and `text_embeddings` to TextVAE.
# NOTE(review): the literal 32 is presumably the embedding size; the same
# literal is passed to TextVAE in the training cell -- confirm they must match.
text_embeddings, word_index_mapping = generate_text_embeddings(post_descriptions, 32)

In [6]:
# Create datasets
#
# Wraps the cleaned descriptions and their class labels in a TextDataset and
# splits it 80/20 into train and test subsets.
dataset = TextDataset(post_descriptions, word_index_mapping, post_classes)

# Seed the split so that re-running the notebook (or just this cell) yields the
# same train/test partition; an unseeded split changes on every run, which makes
# results irreproducible.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, test_dataset = random_split(
    dataset, [0.8, 0.2], generator=split_generator
)

# The two subsets must partition the full dataset.
assert len(train_dataset) + len(test_dataset) == len(dataset)

print(len(dataset))
print(len(train_dataset))
print(len(test_dataset))

373325
298660
74665


In [None]:
# Train model
#
# Builds the data loaders, the TextVAE and its Adam optimizer, then makes
# `num_epochs` passes over the training split. After each pass it prints the
# classifier accuracy over the training set and the mean per-batch loss.

num_epochs = 10
batch_size = 32

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

# Loader over the first 1000 examples; it is not used by the loop below.
small_dataset = TextDataset(
    post_descriptions[:1000], word_index_mapping, post_classes[:1000]
)
small_dataloader = DataLoader(small_dataset, shuffle=False, batch_size=batch_size)

if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

print(f"Using {device} device")

model = TextVAE(
    device,
    32,
    32,
    text_embeddings,
    dataset.padding_index,
    dataset.num_post_classes,
)
model = model.to(device)
optimizer = Adam(model.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    correct = 0

    for text_batch, party_labels in train_dataloader:
        text_batch = text_batch.to(device)

        # Clear gradients left over from the previous step before the new pass.
        optimizer.zero_grad()

        reconstruction, mu, logvar, classifier_output = model(text_batch)
        batch_loss = model.loss_function(
            text_batch, party_labels, reconstruction, mu, logvar, classifier_output
        )
        batch_loss.backward()
        optimizer.step()

        train_loss += batch_loss.item()

        # Labels stay on the CPU, so predictions are moved there to compare.
        predicted_party = classifier_output.argmax(dim=-1).detach().cpu()
        correct += (party_labels == predicted_party).sum().item()

    print(f"Train Epoch: {epoch+1} \tAccuracy: {(correct * 100) / len(train_dataset):.2f}% \tLoss: {train_loss / len(train_dataloader):.6f}")


Using cuda:0 device
208232
208231


In [None]:
# Test Model
#
# Runs the trained model over the held-out split and reports classifier
# accuracy and the mean per-batch loss.

model.eval()
test_loss = 0
test_correct = 0

# Evaluation needs no gradients; disabling autograd avoids building a graph for
# every batch.
with torch.no_grad():
    # Loop variables follow the training cell's naming (text, decoded_text).
    for text, party in test_dataloader:
        text = text.to(device)

        decoded_text, mu, logvar, classifier_result = model(text)

        loss = model.loss_function(
            text, party, decoded_text, mu, logvar, classifier_result
        )

        test_loss += loss.item()

        # Labels stay on the CPU, so predictions are moved there to compare.
        test_correct += (party == torch.argmax(classifier_result, dim=-1).cpu()).sum().item()

print(f"Test Metrics \tAccuracy: {(test_correct * 100) / len(test_dataset):.2f}% \tLoss: {test_loss / len(test_dataloader):.6f}")
