In [1]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import logging

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Use the GPU only when one is actually present; a hard-coded 'cuda' makes every
# later `.to(device)` call fail on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The training file is a JSON object; each (key, value) pair is kept as a
# (word, labels) tuple. It is read as UTF-8 explicitly, matching how
# predict_and_save reads its input, so the character vocabulary does not depend
# on the platform's default encoding.
with open("./train.json", "r", encoding="utf-8") as f:
    data = list(json.load(f).items())

# Character vocabulary
chars = sorted(set("".join(word for word, _ in data)))
char2idx = {char: idx + 1 for idx, char in enumerate(chars)}  # 0 is reserved for padding
idx2char = {idx: char for char, idx in char2idx.items()}
vocab_size = len(chars)  # number of distinct characters, excluding the padding index


# Define Dataset
class CompoundDataset(Dataset):
    """Dataset over (word, labels) pairs that yields tensors.

    Each item is a pair ``(char_indices, labels)`` where ``char_indices`` is a
    ``torch.long`` tensor obtained by looking every character of the word up in
    ``char2idx`` and ``labels`` is the label sequence as a ``torch.float`` tensor.
    """

    def __init__(self, data, char2idx):
        self.data = data
        self.char2idx = char2idx

    def __len__(self):
        return len(self.data)

    def encode(self, word, labels):
        """Convert one word and its labels into a pair of tensors.

        Raises ``KeyError`` if the word contains a character missing from
        ``char2idx``.
        """
        lookup = self.char2idx
        char_ids = [lookup[ch] for ch in word]
        ids_tensor = torch.tensor(char_ids, dtype=torch.long)
        labels_tensor = torch.tensor(labels, dtype=torch.float)
        return (ids_tensor, labels_tensor)

    def __getitem__(self, idx):
        entry = self.data[idx]
        word, labels = entry
        return self.encode(word, labels)


# Collate function to handle batching
def collate_fn(batch):
    """Pad a list of (input, target) tensor pairs into two dense batch tensors.

    Both outputs are zero-filled to the length of the longest input sequence.

    Returns ``(padded_inputs, padded_targets, lengths)`` where ``lengths`` is
    the list of original input lengths, in batch order.
    """
    sequences, labels = zip(*batch)
    lengths = [len(item) for item in sequences]
    width = max(lengths)

    # Zero doubles as the padding value for both index and label tensors.
    padded_inputs = torch.zeros(len(sequences), width, dtype=torch.long)
    padded_targets = torch.zeros(len(labels), width, dtype=torch.float)

    for row, seq in enumerate(sequences):
        padded_inputs[row, : len(seq)] = seq
    for row, tgt in enumerate(labels):
        padded_targets[row, : len(tgt)] = tgt

    return padded_inputs, padded_targets, lengths


# Define the FCN Model
class FCNSegmentation(nn.Module):
    """Fully convolutional per-position scorer over character indices.

    Character indices are embedded, passed through four kernel-size-3
    convolutions (padding 1, so the sequence length is preserved), and reduced
    by a kernel-size-1 convolution plus sigmoid to one value per position.
    """

    def __init__(self, vocab_size, embedding_dim=64, num_channels=128):
        super().__init__()
        # One extra row because index 0 is the padding index.
        self.embedding = nn.Embedding(vocab_size + 1, embedding_dim, padding_idx=0)

        half = num_channels // 2
        quarter = num_channels // 4
        # Channel widths of the kernel-size-3 stack, input first.
        widths = [embedding_dim, half, quarter, half, num_channels]

        layers = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv1d(in_ch, out_ch, kernel_size=3, padding=1))
            layers.append(nn.ReLU())
        # Pointwise projection to a single channel, squashed into (0, 1).
        layers.append(nn.Conv1d(num_channels, 1, kernel_size=1))
        layers.append(nn.Sigmoid())
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Map indices of shape (batch, seq) to scores of shape (batch, seq)."""
        embedded = self.embedding(x)  # (batch, seq, embedding_dim)
        channels_first = embedded.permute(0, 2, 1)  # Conv1d expects (batch, channels, seq)
        scores = self.conv(channels_first)  # (batch, 1, seq)
        return scores.squeeze(1)

def train():
    """Train an FCNSegmentation model on the module-level training data.

    Builds a shuffled DataLoader over ``data``, optimises binary cross-entropy
    with Adam for a fixed number of epochs, logs the mean batch loss after each
    epoch, and returns the trained model. Padding positions are excluded from
    the loss.
    """
    batch_size = 128
    num_epochs = 32

    # Data pipeline
    dataloader = DataLoader(
        CompoundDataset(data, char2idx),
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=4,
        prefetch_factor=2,
    )

    # Model, loss and optimiser
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = FCNSegmentation(vocab_size).to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        for inputs, targets, lengths in dataloader:
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs.to(device))

            # True wherever a position lies inside the real (unpadded) sequence.
            positions = torch.arange(inputs.shape[1])[None, :]
            limits = torch.tensor(lengths)[:, None]
            mask = (positions < limits).to(device)

            loss = criterion(outputs[mask], targets[mask])
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        logging.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(dataloader):.4f}")

    return model


def predict_and_save(model, input_file, output_file, char2idx, device="cpu", threshold=0.6):
    """
    Reads a JSON file, predicts segmentation for each word, and saves the results to a new JSON file.

    Only the keys of the input JSON object are used as words; the values are ignored.
    Characters missing from ``char2idx`` are encoded as index 0, the padding index.
    An empty word is written with an empty prediction list and never reaches the model.

    :param model: Trained model
    :param input_file: Path to input JSON file
    :param output_file: Path to output JSON file
    :param char2idx: Character to index mapping
    :param device: Device to run the model on; the model is moved there before predicting
    :param threshold: A position is labelled 1 when the model output is strictly greater than this value
    """
    # Load the input data
    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Initialize predictions dictionary
    predictions = {}

    # Inputs are created on `device`, so the model must live there as well;
    # otherwise the forward pass fails with a device mismatch.
    model = model.to(device)

    # Set model to evaluation mode
    model.eval()

    # Predict for each word
    with torch.no_grad():
        for word in data:
            if not word:
                # Nothing to score, and a zero-length sequence may be rejected
                # by the model's layers.
                predictions[word] = []
                continue

            # Convert word to indices
            indices = [char2idx.get(char, 0) for char in word]
            input_tensor = torch.tensor(indices, dtype=torch.long).unsqueeze(0).to(device)

            # Get model outputs for the single sequence in the batch
            outputs = model(input_tensor)[0].cpu().numpy()

            # Convert outputs to binary labels
            boundaries = (outputs > threshold).astype(int)
            predictions[word] = boundaries.tolist()

    # Save predictions to output file
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(predictions, f, ensure_ascii=False, indent=4)

    logging.info(f"Predictions saved to {output_file}")


# Example Usage
if __name__ == "__main__":
    # Train, then write predictions for the test and validation sets.
    model = train()

    # train() chooses its own device, so ask the model where its parameters
    # live instead of relying on the module-level `device` setting; this keeps
    # the prediction inputs on the same device as the model.
    model_device = next(model.parameters()).device

    for input_file, output_file in (
        ("./test.json", "./submissiontest.json"),
        ("./val.json", "./submissionval.json"),
    ):
        predict_and_save(model, input_file, output_file, char2idx, model_device)
    

2025-02-08 12:21:05,306 - INFO - Epoch 1/32, Loss: 0.1009
2025-02-08 12:21:10,908 - INFO - Epoch 2/32, Loss: 0.0511
2025-02-08 12:21:17,078 - INFO - Epoch 3/32, Loss: 0.0432
2025-02-08 12:21:20,718 - INFO - Epoch 4/32, Loss: 0.0387
2025-02-08 12:21:25,088 - INFO - Epoch 5/32, Loss: 0.0353
2025-02-08 12:21:31,856 - INFO - Epoch 6/32, Loss: 0.0328
2025-02-08 12:21:39,036 - INFO - Epoch 7/32, Loss: 0.0306
2025-02-08 12:21:45,932 - INFO - Epoch 8/32, Loss: 0.0291
2025-02-08 12:21:50,754 - INFO - Epoch 9/32, Loss: 0.0276
2025-02-08 12:21:56,369 - INFO - Epoch 10/32, Loss: 0.0264
2025-02-08 12:22:01,307 - INFO - Epoch 11/32, Loss: 0.0251
2025-02-08 12:22:07,456 - INFO - Epoch 12/32, Loss: 0.0243
2025-02-08 12:22:13,516 - INFO - Epoch 13/32, Loss: 0.0234
2025-02-08 12:22:19,374 - INFO - Epoch 14/32, Loss: 0.0224
2025-02-08 12:22:26,724 - INFO - Epoch 15/32, Loss: 0.0218
2025-02-08 12:22:34,355 - INFO - Epoch 16/32, Loss: 0.0210
2025-02-08 12:22:42,116 - INFO - Epoch 17/32, Loss: 0.0204
2025-0