# In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the Kaggle input directory.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    full_paths = (os.path.join(root_dir, name) for name in file_names)
    for path in full_paths:
        print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import Counter
import nltk
from nltk.tokenize import word_tokenize

# Download the NLTK tokenizer models that word_tokenize needs. Newer NLTK
# releases load 'punkt_tab' (the pickle-free replacement for 'punkt') and raise
# LookupError when it is missing, while older releases still load 'punkt'.
# Fetching both keeps the notebook working on either; by default nltk.download
# reports an unknown package and returns False instead of raising.
nltk.download('punkt')
nltk.download('punkt_tab')

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Custom Dataset
class AGNewsDataset(Dataset):
    """Map-style dataset yielding fixed-length token-id tensors and class labels.

    Each text is lower-cased and tokenized. Tokens are mapped through ``vocab``,
    with unknown tokens becoming ``vocab['<unk>']``. The sequence is truncated to
    ``max_len`` tokens and right-padded with ``vocab['<pad>']``. All texts are
    encoded once in ``__init__`` rather than in ``__getitem__``, so the
    tokenizer does not re-run on every epoch. Changing ``texts``, ``vocab`` or
    ``max_len`` after construction therefore has no effect on the encoded data.

    Args:
        texts: Indexable sequence of raw strings (e.g. a NumPy object array).
        labels: Indexable sequence of integer class ids, aligned with ``texts``.
        vocab: Token -> index mapping; must contain ``'<pad>'`` and ``'<unk>'``.
        max_len: Length every encoded sequence is truncated/padded to.
        tokenizer: Callable ``str -> list[str]``; defaults to NLTK's
            ``word_tokenize``. Should match the tokenizer used to build ``vocab``.
    """

    def __init__(self, texts, labels, vocab, max_len=100, tokenizer=None):
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.max_len = max_len
        # Resolve the default lazily so callers can inject a different tokenizer.
        self.tokenizer = word_tokenize if tokenizer is None else tokenizer
        self._encoded = self._encode_all()

    def _encode_all(self):
        """Encode every text into one ``(len(texts), max_len)`` int64 tensor."""
        pad_idx = self.vocab['<pad>']
        unk_idx = self.vocab['<unk>']
        # Start fully padded; each row's first len(tokens) slots are overwritten.
        encoded = np.full((len(self.texts), self.max_len), pad_idx, dtype=np.int64)
        for row, text in enumerate(self.texts):
            tokens = self.tokenizer(text.lower())[:self.max_len]
            encoded[row, :len(tokens)] = [self.vocab.get(tok, unk_idx) for tok in tokens]
        return torch.from_numpy(encoded)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(token_ids, label)`` as int64 tensors for sample ``idx``."""
        # clone() hands out an independent tensor rather than a view into the cache.
        return self._encoded[idx].clone(), torch.tensor(self.labels[idx], dtype=torch.long)

# CNN Model
class TextCNN(nn.Module):
    """Text classifier: embeddings -> parallel 1-D convolutions -> max-over-time pooling -> linear.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding width (also the conv input channels).
        num_classes: Number of output logits.
        filter_sizes: Kernel widths, one ``Conv1d`` per entry (any iterable of ints).
        num_filters: Output channels of each convolution.
        dropout: Dropout probability applied to the pooled features.
        padding_idx: Embedding row kept at zero and excluded from gradient updates.
    """

    def __init__(self, vocab_size, embed_dim, num_classes, filter_sizes=(3, 4, 5), num_filters=100,
                 dropout=0.5, padding_idx=0):
        super().__init__()
        # Tuple default instead of a shared mutable list; materialize so len() works on any iterable.
        filter_sizes = tuple(filter_sizes)
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.convs = nn.ModuleList(nn.Conv1d(embed_dim, num_filters, fs) for fs in filter_sizes)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(len(filter_sizes) * num_filters, num_classes)

    def forward(self, x):
        """Map token ids ``x`` of shape (batch, seq_len) to logits (batch, num_classes).

        The convolutions are unpadded, so ``seq_len`` must be at least the
        largest filter size.
        """
        embedded = self.embedding(x).transpose(1, 2)  # (batch, embed_dim, seq_len)
        # Each conv output: (batch, num_filters, seq_len - fs + 1); max over time -> (batch, num_filters)
        pooled = [torch.relu(conv(embedded)).max(dim=2).values for conv in self.convs]
        features = self.dropout(torch.cat(pooled, dim=1))  # (batch, num_filters * len(filter_sizes))
        return self.fc(features)

# Data Preprocessing
def load_data(file_path, text_column='Description', label_column='Class Index'):
    """Read texts and zero-based integer labels from an AG News-style CSV.

    Args:
        file_path: Path or file-like object accepted by ``pd.read_csv``.
        text_column: Column holding the text to classify.
        label_column: Column holding 1-based class indices.

    Returns:
        ``(texts, labels)`` as NumPy arrays; labels are shifted to start at 0.
    """
    df = pd.read_csv(file_path, usecols=[text_column, label_column])
    # Empty cells come back as float NaN, which would crash .lower() during
    # tokenization; treat them as empty strings instead.
    texts = df[text_column].fillna('').astype(str).to_numpy()
    labels = df[label_column].to_numpy() - 1  # Convert to 0-based indexing
    return texts, labels

def build_vocab(texts, min_freq=5, tokenizer=None):
    """Build a token -> index vocabulary from lower-cased, tokenized texts.

    Index 0 is ``'<pad>'`` and index 1 is ``'<unk>'``. Every token seen at
    least ``min_freq`` times gets the next free index, in first-seen order.

    Args:
        texts: Iterable of raw strings.
        min_freq: Minimum occurrence count for a token to be kept.
        tokenizer: Callable ``str -> list[str]``; defaults to NLTK's ``word_tokenize``.

    Returns:
        dict mapping token to a contiguous index in ``range(len(vocab))``.
    """
    tokenize = word_tokenize if tokenizer is None else tokenizer
    word_counts = Counter()
    for text in texts:
        word_counts.update(tokenize(text.lower()))

    vocab = {'<pad>': 0, '<unk>': 1}
    for word, count in word_counts.items():
        # A tokenizer may emit a reserved special such as '<unk>'. Reassigning it
        # would orphan its old index and break the contiguous id range, so skip it.
        if count >= min_freq and word not in vocab:
            vocab[word] = len(vocab)
    return vocab

# Collate function for DataLoader
def collate_batch(batch):
    """Stack ``(token_ids, label)`` samples into batch tensors on the global ``device``.

    Returns:
        ``(texts, labels)``: texts stacked along a new batch dimension, labels
        as an int64 vector.

    NOTE: moving tensors to CUDA inside a collate_fn is only safe with the
    DataLoader's default ``num_workers=0`` (CUDA cannot be used in forked workers).
    """
    texts, labels = zip(*batch)
    text_batch = torch.stack(texts)
    # Stack the per-sample 0-d label tensors directly instead of routing a tuple
    # of tensors through torch.tensor; as_tensor still accepts plain int labels.
    label_batch = torch.stack([torch.as_tensor(label) for label in labels])
    return text_batch.to(device), label_batch.to(device=device, dtype=torch.long)

# Training Loop with Metrics Collection
def _run_epoch(model, loader, criterion, optimizer=None, desc=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Trains (``model.train()``, backprop, ``optimizer.step()``) when ``optimizer``
    is given; otherwise evaluates in eval mode with gradients disabled. Metrics
    are averaged per sample, weighting each batch by its size, so a short final
    batch does not skew them. This assumes ``criterion`` returns a per-batch
    mean (``nn.CrossEntropyLoss``'s default reduction).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    # Show a progress bar for training passes only.
    batches = tqdm(loader, desc=desc) if training else loader
    loss_sum, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for texts, labels in batches:
            if training:
                optimizer.zero_grad()
            outputs = model(texts)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
            seen += batch_size
    if seen == 0:
        raise ValueError("loader yielded no samples; cannot compute epoch metrics")
    return loss_sum / seen, correct / seen


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    The model is moved to the module-level ``device``. Batches are expected to
    already be on that device (``collate_batch`` moves them).

    Returns:
        ``(train_losses, val_losses, train_accs, val_accs)``: per-epoch lists of
        sample-averaged loss and accuracy.
    """
    model.to(device)
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, optimizer, desc=f"Epoch {epoch+1}/{num_epochs}"
        )
        val_loss, val_acc = _run_epoch(model, val_loader, criterion)

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    return train_losses, val_losses, train_accs, val_accs

# Plotting Function
def plot_metrics(train_losses, val_losses, train_accs, val_accs, num_epochs):
    """Show side-by-side train/validation loss and accuracy curves.

    Each metric list must hold one value per epoch, i.e. ``num_epochs`` entries.
    """
    x_axis = range(1, num_epochs + 1)
    # One tuple per panel: (y-axis label, train series, val series, train label, val label, title).
    panels = (
        ('Loss', train_losses, val_losses, 'Train Loss', 'Validation Loss',
         'Training and Validation Loss'),
        ('Accuracy', train_accs, val_accs, 'Train Accuracy', 'Validation Accuracy',
         'Training and Validation Accuracy'),
    )

    plt.figure(figsize=(12, 5))
    for position, (y_name, train_series, val_series, train_name, val_name, heading) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(x_axis, train_series, label=train_name)
        plt.plot(x_axis, val_series, label=val_name)
        plt.xlabel('Epoch')
        plt.ylabel(y_name)
        plt.title(heading)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.show()

# Main
def main():
    """Train a TextCNN on the AG News training CSV, plot its curves and save it.

    Reads the Kaggle AG News ``train.csv``, holds out 20% for validation,
    trains for ``NUM_EPOCHS`` epochs, plots loss/accuracy, and writes the
    weights to ``ag_news_text_cnn_model.pth`` in the working directory.
    """
    # Hyperparameters
    EMBED_DIM = 100
    NUM_FILTERS = 100
    FILTER_SIZES = [3, 4, 5]
    BATCH_SIZE = 32
    NUM_EPOCHS = 25
    LEARNING_RATE = 0.001
    MAX_LEN = 100
    MIN_FREQ = 5
    SEED = 42

    # Seed torch's RNG (weight init, dropout, DataLoader shuffling) so runs are
    # repeatable; the train/validation split is seeded separately below.
    torch.manual_seed(SEED)

    # Load data
    file_path = "/kaggle/input/ag-news-classification-dataset/train.csv"
    texts, labels = load_data(file_path)

    # Split BEFORE building the vocabulary so that word frequencies from the
    # validation split do not decide which tokens get embeddings (leakage).
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels, test_size=0.2, random_state=SEED
    )
    vocab = build_vocab(train_texts, min_freq=MIN_FREQ)

    # Create datasets and loaders
    train_dataset = AGNewsDataset(train_texts, train_labels, vocab, max_len=MAX_LEN)
    val_dataset = AGNewsDataset(val_texts, val_labels, vocab, max_len=MAX_LEN)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=collate_batch)

    # Initialize model
    num_classes = 4  # AG News: World, Sports, Business, Sci/Tech
    model = TextCNN(len(vocab), EMBED_DIM, num_classes, FILTER_SIZES, NUM_FILTERS)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Train and collect metrics
    train_losses, val_losses, train_accs, val_accs = train_model(
        model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS
    )

    # Plot metrics
    plot_metrics(train_losses, val_losses, train_accs, val_accs, NUM_EPOCHS)

    # Save model. The weights stay on `device`, so loading them on a CPU-only
    # host needs torch.load(..., map_location="cpu").
    torch.save(model.state_dict(), "ag_news_text_cnn_model.pth")


if __name__ == "__main__":
    main()