In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import time
import psutil

In [None]:
# Hyperparameters shared by the model definition and the training loop.

# --- Model architecture ---
INPUT_DIM = 768    # feature size of each LSTM input step
HIDDEN_DIM = 256   # LSTM hidden-state size
NUM_LAYERS = 2     # number of stacked LSTM layers
DROPOUT = 0.5      # dropout value handed to nn.LSTM
OUTPUT_DIM = 33    # width of the final linear layer's output

# --- Optimisation ---
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
NUM_EPOCHS = 50

In [None]:
class EmbeddingDataset(Dataset):
    """Map-style dataset that pairs each stored embedding with its label.

    Both containers are kept exactly as passed in and are indexed with the
    same position, so they are expected to line up one-to-one.
    """

    def __init__(self, embeddings, labels):
        self.labels = labels
        self.embeddings = embeddings

    def __len__(self):
        # The label container determines how many samples the dataset reports.
        sample_count = len(self.labels)
        return sample_count

    def __getitem__(self, idx):
        """Return ``(embedding as a float32 tensor, label)`` for position ``idx``."""
        features = torch.tensor(self.embeddings[idx], dtype=torch.float32)
        label = self.labels[idx]
        return features, label

In [None]:
class LSTMModel(nn.Module):
    """Stacked LSTM followed by a linear layer applied to the last time step.

    Args:
        input_dim: Feature size of each time step (default: ``INPUT_DIM``).
        hidden_dim: LSTM hidden-state size (default: ``HIDDEN_DIM``).
        output_dim: Number of output logits (default: ``OUTPUT_DIM``).
        num_layers: Number of stacked LSTM layers (default: ``NUM_LAYERS``).
        dropout: Dropout value passed to ``nn.LSTM`` (default: ``DROPOUT``).

    Calling ``LSTMModel()`` with no arguments builds the same network as
    before; the arguments only exist so other sizes can be used.
    """

    def __init__(self, input_dim=None, hidden_dim=None, output_dim=None,
                 num_layers=None, dropout=None):
        super().__init__()
        # Resolve defaults at call time so the module-level constants are only
        # consulted for the values the caller did not supply.
        input_dim = INPUT_DIM if input_dim is None else input_dim
        hidden_dim = HIDDEN_DIM if hidden_dim is None else hidden_dim
        output_dim = OUTPUT_DIM if output_dim is None else output_dim
        num_layers = NUM_LAYERS if num_layers is None else num_layers
        dropout = DROPOUT if dropout is None else dropout

        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers,
                            dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return logits of shape ``(batch, output_dim)``.

        Args:
            x: Either ``(batch, seq_len, input_dim)`` or ``(batch, input_dim)``.
                A 2-D input is treated as a batch of length-1 sequences; the
                three-index slice below cannot handle the 2-D LSTM output it
                would otherwise produce.
        """
        # NOTE(review): the notebook's DataLoader may yield one vector per
        # sample (a 2-D batch) -- confirm the stored embedding shape.
        if x.dim() == 2:
            x = x.unsqueeze(1)
        output, _ = self.lstm(x)
        # Classify from the hidden output of the final time step only.
        return self.fc(output[:, -1, :])

In [None]:
def setup(rank, world_size):
    """Bind this worker to its GPU and join the NCCL process group.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of processes in the group.
    """
    # Keep a rendezvous address/port already present in the environment
    # (e.g. set by a launcher or by the user to avoid a port clash) and only
    # fall back to the single-node defaults when nothing is set.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    # Select the device before the NCCL group is created so the group is not
    # set up while a different GPU is still the current device.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Tear down the default process group if one is active.

    Safe to call when no group was initialised (for example from a
    ``finally`` block after initialisation failed): it then does nothing
    instead of raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

In [None]:
def _evaluate(ddp_model, loader, rank, world_size):
    """Run inference over ``loader`` and collect results from every rank.

    This is a collective call: every rank must invoke it.

    Returns:
        ``(labels, preds)`` as flat Python lists, concatenated over all ranks.
        ``DistributedSampler`` pads shards to equal length by repeating
        samples, so a few samples may appear twice in the result.
    """
    ddp_model.eval()
    local_labels, local_preds = [], []
    with torch.no_grad():
        for data, target in loader:
            output = ddp_model(data.to(rank))
            local_preds.extend(output.argmax(dim=1).cpu().tolist())
            local_labels.extend(target.tolist())

    # Exchange the per-rank shards so metrics describe the whole split rather
    # than only the slice that happened to land on one rank.
    gathered = [None] * world_size
    dist.all_gather_object(gathered, (local_labels, local_preds))
    labels = [y for shard_labels, _ in gathered for y in shard_labels]
    preds = [p for _, shard_preds in gathered for p in shard_preds]
    return labels, preds


def _macro_metrics(labels, preds):
    """Return ``(accuracy, precision, recall, f1)`` with macro averaging."""
    return (
        accuracy_score(labels, preds),
        precision_score(labels, preds, average='macro'),
        recall_score(labels, preds, average='macro'),
        f1_score(labels, preds, average='macro'),
    )


def train(rank, world_size):
    """Train the LSTM classifier with DistributedDataParallel on one GPU per rank.

    Loads the train/val/test embedding files, trains for ``NUM_EPOCHS`` epochs,
    reports loss, timing, memory and validation metrics after every epoch, and
    reports test metrics once training is done. Only rank 0 prints.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of participating processes.
    """
    setup(rank, world_size)
    try:
        # Load and prepare data
        train_data = pd.read_hdf('train_embeddings.h5')
        test_data = pd.read_hdf('test_embeddings.h5')
        val_data = pd.read_hdf('val_embeddings.h5')

        # Fit the label mapping on the training split only, then reuse it.
        label_encoder = LabelEncoder()
        train_data['source'] = label_encoder.fit_transform(train_data['source'])
        test_data['source'] = label_encoder.transform(test_data['source'])
        val_data['source'] = label_encoder.transform(val_data['source'])

        train_dataset = EmbeddingDataset(np.array(train_data['gpt2_embeddings']), np.array(train_data['source']))
        test_dataset = EmbeddingDataset(np.array(test_data['gpt2_embeddings']), np.array(test_data['source']))
        val_dataset = EmbeddingDataset(np.array(val_data['gpt2_embeddings']), np.array(val_data['source']))

        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
        # Evaluation does not benefit from shuffling; keep the order fixed.
        test_sampler = DistributedSampler(test_dataset, num_replicas=world_size, rank=rank, shuffle=False)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, sampler=test_sampler)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, sampler=val_sampler)

        model = LSTMModel().to(rank)
        ddp_model = DDP(model, device_ids=[rank])

        optimizer = optim.Adam(ddp_model.parameters(), lr=LEARNING_RATE)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(NUM_EPOCHS):
            # Without this the sampler reuses the same permutation every epoch.
            train_sampler.set_epoch(epoch)

            ddp_model.train()
            epoch_start_time = time.time()
            total_loss = 0.0
            for data, target in train_loader:
                data, target = data.to(rank), target.to(rank)
                optimizer.zero_grad()
                output = ddp_model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            # Calculate time and resource usage
            epoch_time = time.time() - epoch_start_time
            gpu_memory = torch.cuda.memory_allocated(rank) / (1024 ** 3)
            system_memory = psutil.virtual_memory().used / (1024 ** 3)

            # Sum loss and batch counts over all ranks so the reported value is
            # the mean over every batch processed this epoch, not rank 0's share.
            loss_stats = torch.tensor(
                [total_loss, float(len(train_loader))],
                dtype=torch.float64,
                device=torch.device('cuda', rank),
            )
            dist.all_reduce(loss_stats, op=dist.ReduceOp.SUM)
            mean_loss = (loss_stats[0] / loss_stats[1]).item()

            # Evaluate on validation set (collective: every rank takes part)
            val_labels, val_preds = _evaluate(ddp_model, val_loader, rank, world_size)

            if rank == 0:
                accuracy, precision, recall, f1 = _macro_metrics(val_labels, val_preds)
                print(f'Epoch {epoch+1}: Loss={mean_loss}, Time={epoch_time}, GPU Mem={gpu_memory}GB, Sys Mem={system_memory}GB')
                print(f'Val Metrics - Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, F1: {f1}')

        # Final held-out evaluation; the test loader was previously built but never used.
        test_labels, test_preds = _evaluate(ddp_model, test_loader, rank, world_size)
        if rank == 0:
            accuracy, precision, recall, f1 = _macro_metrics(test_labels, test_preds)
            print(f'Test Metrics - Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, F1: {f1}')
    finally:
        # Always release the process group, even if training raised.
        cleanup()

In [None]:
# One training process is launched per visible CUDA device.
world_size = torch.cuda.device_count()

# The spawn start method re-imports the main module in every worker, so the
# launch must be guarded or each worker would try to spawn again when this
# code is run as a script.
# NOTE(review): functions defined only inside an interactive kernel may not be
# importable by spawned workers -- if launching from the notebook fails,
# move the definitions into a .py module and import `train` from there.
if __name__ == "__main__":
    if world_size < 1:
        # With zero processes spawn() would return without training anything.
        raise RuntimeError(
            "No CUDA devices found; NCCL-based distributed training requires at least one GPU."
        )
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)