In [None]:
# Reinstall the PyTorch family, searching the CUDA 11.8 wheel index in addition to PyPI.
# NOTE(review): the captured output shows this replacing torch 2.6.0+cu124 and resolving
# torch 2.8.0 / torchvision 0.23.0 / torchaudio 2.8.0 / torchtext 0.18.0. None of
# torchtext, torchvision or torchaudio is imported in the visible code, and torchtext 0.18
# may not be binary-compatible with torch 2.8 — confirm this cell is still needed.
!pip uninstall -y torchtext torchvision torchaudio torch
!pip install torch torchvision torchaudio torchtext --extra-index-url https://download.pytorch.org/whl/cu118


Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu118
Collecting torch
  Downloading torch-2.8.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting torchvision
  Downloading torchvision-0.23.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio
  Downloading torchaudio-2.8.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (7.2 kB)
Collecting torchtext
  Downloading torchtext-0.18.0-cp311-cp311-manylinux1_x86_64.whl.metadata (7.9 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl

In [None]:
# ==========================================
# 1. Install dependencies
# ==========================================
!pip install tokenizers rouge-score torch --quiet

# ==========================================
# 2. Upload and extract dataset, flexible path
# ==========================================
from google.colab import files
import zipfile
import os

print("Upload your IN-Abs-small.zip file now.")
uploaded = files.upload()
zip_filename = list(uploaded.keys())[0]
print(f"Uploaded file: {zip_filename}")

# Extract to /content/IN-Abs-small
with zipfile.ZipFile(zip_filename, 'r') as z:
    z.extractall("/content/IN-Abs-small")
print("Extraction complete.")

# ==========================================
# 3. Resolve dataset root (nested IN-Abs-small)
# ==========================================
def find_data_root(base_path):
    """Return the directory that actually holds the dataset splits.

    The uploaded archive may unpack into an extra ``IN-Abs-small`` folder.
    If *base_path* contains such a sub-directory, that path is returned;
    otherwise *base_path* itself is returned unchanged.
    """
    nested = os.path.join(base_path, "IN-Abs-small")
    return nested if os.path.isdir(nested) else base_path

# Locate the folder holding train-data/ and test-data/ inside the extraction target.
base_extracted_path = "/content/IN-Abs-small"
dataset_root = find_data_root(base_extracted_path)
train_dir, test_dir = (
    os.path.join(dataset_root, split_name) for split_name in ("train-data", "test-data")
)

print("Train data directory:", train_dir)
print("Test data directory: ", test_dir)

# ==========================================
# 4. Load judgement-summary pairs
# ==========================================
def load_judgement_summary_pairs(split_dir):
    """Read index-aligned judgement/summary texts from one dataset split.

    Every ``*.txt`` file in ``<split_dir>/judgement`` is taken in sorted
    filename order and paired with the same-named file in
    ``<split_dir>/summary``. Both texts are read as UTF-8 and stripped of
    surrounding whitespace.

    Returns:
        ``(inputs, targets)``: two lists of strings, aligned by index.

    Raises:
        FileNotFoundError: if a judgement file has no same-named summary file
            (or either sub-directory is missing).
    """
    judgement_dir = os.path.join(split_dir, "judgement")
    summary_dir = os.path.join(split_dir, "summary")
    names = sorted(n for n in os.listdir(judgement_dir) if n.endswith(".txt"))

    pairs = []
    for name in names:
        with open(os.path.join(judgement_dir, name), encoding="utf-8") as judgement_file, \
             open(os.path.join(summary_dir, name), encoding="utf-8") as summary_file:
            pairs.append((judgement_file.read().strip(), summary_file.read().strip()))

    inputs = [judgement for judgement, _ in pairs]
    targets = [summary for _, summary in pairs]
    return inputs, targets

# Load both splits as parallel lists of judgement texts and reference summaries.
(train_inputs, train_targets), (test_inputs, test_targets) = (
    load_judgement_summary_pairs(train_dir),
    load_judgement_summary_pairs(test_dir),
)
print(f"Loaded {len(train_inputs)} training samples and {len(test_inputs)} test samples.")

# ==========================================
# 5. Tokenizers
# ==========================================
from tokenizers import Tokenizer, models, pre_tokenizers, trainers, processors

def train_tokenizer(texts, vocab_size=32000):
    """Train a whitespace-split, word-level tokenizer on *texts*.

    ``[UNK]``, ``[PAD]``, ``[BOS]`` and ``[EOS]`` are reserved as special
    tokens (in that order). A post-processor wraps every single-sequence
    encoding as ``[BOS] ... [EOS]``.

    Args:
        texts: iterable of training strings.
        vocab_size: maximum vocabulary size passed to the trainer.

    Returns:
        The trained ``tokenizers.Tokenizer``.
    """
    reserved = ["[UNK]", "[PAD]", "[BOS]", "[EOS]"]
    tok = Tokenizer(models.WordLevel(unk_token="[UNK]"))
    tok.pre_tokenizer = pre_tokenizers.Whitespace()
    word_trainer = trainers.WordLevelTrainer(vocab_size=vocab_size, special_tokens=reserved)
    tok.train_from_iterator(texts, trainer=word_trainer)

    # Special-token ids are only known after training, so look them up here.
    bos_entry = ("[BOS]", tok.token_to_id("[BOS]"))
    eos_entry = ("[EOS]", tok.token_to_id("[EOS]"))
    tok.post_processor = processors.TemplateProcessing(
        single="[BOS] $A [EOS]",
        special_tokens=[bos_entry, eos_entry],
    )
    return tok

# Separate word-level vocabularies: judgements feed the encoder, summaries the decoder.
input_tokenizer, target_tokenizer = (
    train_tokenizer(corpus) for corpus in (train_inputs, train_targets)
)

def encode_batch(tokenizer, texts):
    """Encode each text with *tokenizer* and return one 1-D id tensor per text.

    Args:
        tokenizer: object whose ``encode(text)`` result exposes ``.ids``.
        texts: iterable of strings.

    Returns:
        A list of ``torch.tensor`` objects, in the same order as *texts*.
    """
    import torch
    encoded = []
    for text in texts:
        token_ids = tokenizer.encode(text).ids
        encoded.append(torch.tensor(token_ids))
    return encoded

# Encode every split into per-example 1-D id tensors wrapped as [BOS] ... [EOS].
(
    train_input_seqs,
    train_target_seqs,
    test_input_seqs,
    test_target_seqs,
) = [
    encode_batch(tok, texts)
    for tok, texts in (
        (input_tokenizer, train_inputs),
        (target_tokenizer, train_targets),
        (input_tokenizer, test_inputs),
        (target_tokenizer, test_targets),
    )
]

# Vocabulary sizes for the embedding/output layers; pad_id is the target-side [PAD] id.
input_vocab_size = input_tokenizer.get_vocab_size()
target_vocab_size = target_tokenizer.get_vocab_size()
pad_id = target_tokenizer.token_to_id("[PAD]")

# ==========================================
# 6. DataLoader
# ==========================================
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import torch

class TextPairDataset(Dataset):
    """Map-style dataset of index-aligned ``(source, target)`` pairs."""

    def __init__(self, srcs, tgts):
        self.srcs, self.tgts = srcs, tgts

    def __len__(self):
        # Length follows the sources; targets are presumably the same length.
        return len(self.srcs)

    def __getitem__(self, idx):
        source, target = self.srcs[idx], self.tgts[idx]
        return source, target

def collate_fn(batch):
    """Right-pad a batch of ``(source, target)`` id tensors into two 2-D tensors.

    Sources are padded with the input tokenizer's ``[PAD]`` id, and targets
    with the module-level ``pad_id``. Both results are batch-first.
    """
    source_list = [pair[0] for pair in batch]
    target_list = [pair[1] for pair in batch]
    source_pad = input_tokenizer.token_to_id("[PAD]")
    return (
        pad_sequence(source_list, batch_first=True, padding_value=source_pad),
        pad_sequence(target_list, batch_first=True, padding_value=pad_id),
    )

# Shuffled mini-batches of padded (source, target) id tensors for training.
batch_size = 16
train_data = DataLoader(
    TextPairDataset(train_input_seqs, train_target_seqs),
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

# ==========================================
# 7. Liquid Time-Constant Model (Encoder, Decoder, LTC Cell)
# ==========================================
import torch.nn as nn

class LTCCell(nn.Module):
    """Recurrent cell that advances its state by one explicit Euler step.

    Each call computes one unit-size step of
    ``dh/dt = (-h + tanh(W_in x + W_rec h)) / tau`` with a learnable
    per-unit time constant ``tau``. ``tau`` is initialised to 1, and then the
    step reduces to ``tanh(W_in x + W_rec h)``.

    Args:
        input_size: width of the per-step input ``x``.
        hidden_size: width of the state ``h``.
        min_tau: lower bound applied to ``tau`` before dividing. Using
            ``relu(tau)`` alone reaches exactly 0 once the optimiser pushes
            ``tau`` to or below zero, which turns the update into ``inf``/``NaN``.
            For ``tau >= min_tau`` the result is unchanged.
    """

    def __init__(self, input_size, hidden_size, min_tau=1e-3):
        super().__init__()
        self.hidden_size = hidden_size
        self.min_tau = min_tau
        self.W_in = nn.Linear(input_size, hidden_size)
        self.W_rec = nn.Linear(hidden_size, hidden_size)
        self.tau = nn.Parameter(torch.ones(hidden_size))
        self.nonlinearity = nn.Tanh()

    def forward(self, x, h):
        """Return the new state after one step from state *h* with input *x*."""
        # Clamp keeps the divisor strictly positive, so the step stays finite.
        tau = torch.clamp(self.tau, min=self.min_tau)
        dx = (-h + self.nonlinearity(self.W_in(x) + self.W_rec(h))) / tau
        return h + dx

class Encoder(nn.Module):
    """Embed a token sequence and fold it into one hidden state with an LTC cell.

    Args:
        vocab_size: number of source token ids.
        embed_size: embedding width fed to the cell.
        hidden_size: width of the recurrent state.
        pad_idx: optional padding token id. When given, time steps whose
            token equals ``pad_idx`` leave the hidden state unchanged. A
            right-padded sequence then ends in the same state it would reach
            unpadded. ``None`` (the default) keeps the plain behaviour of
            stepping over every position, padding included.
            NOTE(review): callers must pass this explicitly, e.g. the input
            tokenizer's ``[PAD]`` id, to enable masking.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, pad_idx=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.ltc = LTCCell(embed_size, hidden_size)
        self.pad_idx = pad_idx

    def forward(self, src):
        """Encode ``src`` of shape ``(batch, seq_len)`` into a ``(batch, hidden_size)`` state."""
        batch_size, seq_len = src.size()
        h = torch.zeros(batch_size, self.ltc.hidden_size, device=src.device)
        emb = self.embedding(src)
        for t in range(seq_len):
            h_next = self.ltc(emb[:, t, :], h)
            if self.pad_idx is None:
                h = h_next
            else:
                # Only rows holding a real token at step t advance their state.
                keep = (src[:, t] != self.pad_idx).unsqueeze(1)
                h = torch.where(keep, h_next, h)
        return h

class Decoder(nn.Module):
    """Teacher-forced LTC decoder that emits vocabulary logits at every step.

    ``forward(trg, hidden)`` starts from *hidden* (the encoder's final state),
    consumes ``trg`` one position at a time, and returns logits of shape
    ``(batch, seq_len, vocab_size)``.
    """

    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.ltc = LTCCell(embed_size, hidden_size)
        self.fc_out = nn.Linear(hidden_size, vocab_size)

    def forward(self, trg, hidden):
        _, seq_len = trg.size()  # also rejects anything that is not 2-D
        state = hidden
        embedded = self.embedding(trg)
        step_logits = []
        for step in range(seq_len):
            state = self.ltc(embedded[:, step, :], state)
            step_logits.append(self.fc_out(state))
        # Stacking along dim 1 equals concatenating per-step logits unsqueezed at dim 1.
        return torch.stack(step_logits, dim=1)

class Seq2SeqLTC(nn.Module):
    """Encoder-decoder model built from two single-cell LTC recurrences.

    The encoder compresses the whole source into one hidden vector. That
    vector seeds the decoder, which runs with teacher forcing on ``trg`` and
    returns per-position target-vocabulary logits.
    """

    def __init__(self, input_vocab_size, target_vocab_size, embed_size=256, hidden_size=512):
        super().__init__()
        self.encoder = Encoder(input_vocab_size, embed_size, hidden_size)
        self.decoder = Decoder(target_vocab_size, embed_size, hidden_size)

    def forward(self, src, trg):
        return self.decoder(trg, self.encoder(src))

# ==========================================
# 8. Training loop with percent progress
# ==========================================
# Use the GPU when one is visible, then build the model, optimiser and loss.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = Seq2SeqLTC(input_vocab_size, target_vocab_size).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# Padded target positions contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)

def train_epoch(epoch, num_epochs, max_grad_norm=None):
    """Run one teacher-forced pass over ``train_data`` and print progress.

    Uses the module-level ``model``, ``optimizer``, ``criterion``,
    ``train_data``, ``device`` and ``target_vocab_size``.

    Args:
        epoch: 1-based index of the current epoch (display only).
        num_epochs: total number of epochs (display only).
        max_grad_norm: if not ``None``, clip the global gradient L2 norm to
            this value before each optimizer step. The recurrent cells are
            unrolled over every token, so long inputs can yield very large
            gradients. ``None`` keeps the unclipped behaviour.

    Returns:
        The mean per-batch loss, or 0.0 if ``train_data`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = len(train_data)
    for i, (src_seqs, tgt_seqs) in enumerate(train_data):
        src_seqs = src_seqs.to(device)
        tgt_seqs = tgt_seqs.to(device)
        # Teacher forcing: feed target tokens [0, n-1) and predict tokens [1, n).
        trg_in = tgt_seqs[:, :-1]
        trg_out = tgt_seqs[:, 1:]
        optimizer.zero_grad()
        output = model(src_seqs, trg_in)
        loss = criterion(output.reshape(-1, target_vocab_size), trg_out.reshape(-1))
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        total_loss += loss.item()
        percent = (i + 1) / num_batches * 100
        print(f"\rEpoch {epoch}/{num_epochs} | {percent:.1f}% training done...", end='')
    # Guard against an empty loader, which would otherwise raise ZeroDivisionError.
    avg_loss = total_loss / num_batches if num_batches else 0.0
    print(f"\rEpoch {epoch}/{num_epochs} | Training complete | Avg Loss: {avg_loss:.4f}")
    return avg_loss

# One constant drives both the loop bound and the "epoch i/N" display,
# so the two cannot drift apart (previously hard-coded as 6 and 5).
num_epochs = 5
for epoch in range(1, num_epochs + 1):
    train_epoch(epoch, num_epochs)

# ==========================================
# 9. Greedy summary generation (preview)
# ==========================================
def generate_summary(src_seq, max_len=50):
    """Greedily decode a summary for one encoded source sequence.

    Uses the module-level ``model``, ``target_tokenizer`` and ``device``.
    Switches ``model`` to eval mode.

    Args:
        src_seq: 1-D tensor of source token ids, as produced by ``encode_batch``.
        max_len: maximum number of tokens to generate. ``[EOS]`` stops
            decoding early and is not included in the output.

    Returns:
        The generated token ids decoded by ``target_tokenizer``.
    """
    model.eval()
    bos_id = target_tokenizer.token_to_id("[BOS]")
    eos_id = target_tokenizer.token_to_id("[EOS]")
    src_batch = src_seq.unsqueeze(0).to(device)
    outputs = []
    with torch.no_grad():
        h = model.encoder(src_batch)
        # Only the most recent token feeds the next step, so keep just that
        # token instead of concatenating an ever-growing prefix tensor.
        prev = torch.tensor([bos_id], device=device)
        for _ in range(max_len):
            h = model.decoder.ltc(model.decoder.embedding(prev), h)
            next_id = model.decoder.fc_out(h).argmax(dim=-1).item()
            if next_id == eos_id:
                break
            outputs.append(next_id)
            prev = torch.tensor([next_id], device=device)
    return target_tokenizer.decode(outputs)

# Show greedy summaries for up to three test documents next to their references.
print("\n=== SAMPLE GENERATION ON TEST ===")
for i, src in enumerate(test_input_seqs[:3]):
    generated = generate_summary(src)
    print(f"\nTest sample {i+1}:")
    print("Generated:", generated)
    print("Reference:", test_targets[i])

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m104.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m71.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m47.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.4 MB/s[0m eta [36m

Saving IN-Abs-small.zip to IN-Abs-small.zip
Uploaded file: IN-Abs-small.zip
Extraction complete.
Train data directory: /content/IN-Abs-small/IN-Abs-small/train-data
Test data directory:  /content/IN-Abs-small/IN-Abs-small/test-data
Loaded 100 training samples and 20 test samples.
Epoch 1/5 | 14.3% training done...