In [1]:
# ========================================
# VERSIONE SICURA - CON CHECKPOINT
# ========================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
import os

# Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# ========================================
# MOUNT GOOGLE DRIVE (ESSENZIALE!)
# ========================================
from google.colab import drive
drive.mount('/content/drive')

# Directory checkpoint
checkpoint_dir = '/content/drive/MyDrive/shakespeare_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
print(f"‚úÖ Checkpoint directory: {checkpoint_dir}")

# ========================================
# HYPERPARAMETERS
# ========================================
vocab_size = 256
d_model = 128
num_heads = 8
num_layers = 6
d_ff = 512
max_seq_len = 128
dropout = 0.1

batch_size = 32
learning_rate = 3e-4
num_epochs = 6  # ‚Üê RIDOTTO a 6 per sicurezza
warmup_steps = 100

temperature = 0.8
top_k = 40

# ========================================
# DATASET
# ========================================
class TextDataset(Dataset):
    def __init__(self, text, seq_len):
        self.data = torch.tensor([ord(c) for c in text], dtype=torch.long)
        self.seq_len = seq_len

    def __len__(self):
        return len(self.data) - self.seq_len

    def __getitem__(self, idx):
        x = self.data[idx:idx+self.seq_len]
        y = self.data[idx+1:idx+self.seq_len+1]
        return x, y

# Scarica dataset
import requests
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
text = requests.get(url).text
print(f"Dataset: {len(text)} characters")

# Split
split = int(0.9 * len(text))
train_text, val_text = text[:split], text[split:]

train_dataset = TextDataset(train_text, max_seq_len)
val_dataset = TextDataset(val_text, max_seq_len)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# ========================================
# POSITIONAL ENCODING
# ========================================
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                            -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

# ========================================
# MULTI-HEAD ATTENTION
# ========================================
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, V)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        out = self.W_o(out)

        return out

# ========================================
# TRANSFORMER BLOCK
# ========================================
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        x = x + self.dropout(self.attn(self.ln1(x), mask))
        x = x + self.dropout(self.ff(self.ln2(x)))
        return x

# ========================================
# GPT MODEL
# ========================================
class GPT(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff,
                 max_seq_len, dropout=0.1):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len)

        self.layers = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        self.dropout = nn.Dropout(dropout)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        batch_size, seq_len = x.shape

        mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).unsqueeze(0).unsqueeze(0)

        x = self.token_embedding(x)
        x = self.pos_encoding(x)
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x, mask)

        x = self.ln_f(x)
        logits = self.head(x)

        return logits

# ========================================
# TRAINING FUNCTIONS
# ========================================
def train_epoch(model, loader, optimizer, scheduler, device):
    model.train()
    total_loss = 0

    for batch_idx, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)

        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()

        if batch_idx % 100 == 0:
            print(f"  Batch {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}")

    return total_loss / len(loader)

def validate(model, loader, device):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
            total_loss += loss.item()

    return total_loss / len(loader)

@torch.no_grad()
def generate(model, prompt, max_new_tokens=100, temperature=1.0, top_k=None):
    model.eval()

    tokens = torch.tensor([ord(c) for c in prompt], dtype=torch.long).unsqueeze(0).to(device)

    for _ in range(max_new_tokens):
        tokens_cond = tokens if tokens.size(1) <= max_seq_len else tokens[:, -max_seq_len:]

        logits = model(tokens_cond)
        logits = logits[:, -1, :] / temperature

        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        tokens = torch.cat([tokens, next_token], dim=1)

    generated = ''.join([chr(int(t)) for t in tokens[0].cpu().numpy()])
    return generated

# ========================================
# MAIN TRAINING - CON CHECKPOINT SYSTEM
# ========================================
model = GPT(vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len, dropout).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.95))
total_steps = len(train_loader) * num_epochs
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps, eta_min=learning_rate/10)

# History
history = {'train_loss': [], 'val_loss': []}
best_val_loss = float('inf')

# Training loop
print("\nüöÄ Starting training...")
for epoch in range(num_epochs):
    print(f"\n{'='*50}")
    print(f"Epoch {epoch+1}/{num_epochs}")
    print('='*50)

    train_loss = train_epoch(model, train_loader, optimizer, scheduler, device)
    val_loss = validate(model, val_loader, device)

    print(f"\nüìä Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)

    # üíæ SALVA CHECKPOINT OGNI EPOCA
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'history': history
    }

    # Salva checkpoint epoch
    checkpoint_path = f'{checkpoint_dir}/checkpoint_epoch_{epoch+1}.pth'
    torch.save(checkpoint, checkpoint_path)
    print(f"üíæ Checkpoint salvato: checkpoint_epoch_{epoch+1}.pth")

    # Salva best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_path = f'{checkpoint_dir}/best_model.pth'
        torch.save(checkpoint, best_path)
        print(f"‚≠ê BEST MODEL! Val Loss: {val_loss:.4f}")

    # Generate sample
    if (epoch + 1) % 2 == 0:
        print(f"\nüé≠ Generated sample:")
        prompt = "ROMEO:"
        generated = generate(model, prompt, max_new_tokens=200, temperature=temperature, top_k=top_k)
        print(generated[:400])

print("\n‚úÖ Training completato!")

# Test finale
print("\n" + "="*50)
print("üé≠ GENERAZIONE FINALE")
print("="*50)

prompts = ["ROMEO:", "JULIET:", "KING LEAR:"]
for prompt in prompts:
    print(f"\nüé¨ Prompt: {prompt}")
    generated = generate(model, prompt, max_new_tokens=150, temperature=0.8, top_k=40)
    print(generated[:300])
    print("-"*50)

# Salva modello finale
final_path = f'{checkpoint_dir}/shakespeare_gpt_final.pth'
torch.save(model.state_dict(), final_path)
print(f"\nüíæ Modello finale salvato: shakespeare_gpt_final.pth")

# Scarica
from google.colab import files
print("\nüì• Vuoi scaricare il modello? Esegui:")
print("files.download(f'{checkpoint_dir}/best_model.pth')")


Using device: cuda


MessageError: Error: credential propagation was unsuccessful