# A2T GPT Model for Travel Domain Transfer Learning

This notebook implements a GPT-based model with the A2T (Attend, Adapt, and Transfer) framework for transfer learning from a generic dataset (Tiny Shakespeare) to a travel domain. It incorporates reinforcement learning algorithms: REINFORCE, Actor-Critic, and Q-learning with DQN. The model generates travel itineraries by adapting knowledge from pre-trained source models.

## Setup Instructions
- Run in Google Colab with a GPU for faster training.
- Install required libraries: `!pip install torch transformers tqdm`.
- The notebook downloads `input.txt` (Tiny Shakespeare) and creates a sample `travel.txt`. Replace `travel.txt` with a larger dataset for better results.
- Choose an RL algorithm (`train_reinforce`, `train_actor_critic`, or `train_dqn`) to train the model.
- After training, generate travel itineraries using the prompt 'Plan a trip to Paris'.

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import deque
import random
import numpy as np
from transformers import pipeline
from tqdm import tqdm

# Hyperparameters
batch_size = 16
block_size = 32
max_iters = 2000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0
vocab_size = None
n_source_tasks = 2
replay_buffer_size = 10000
epsilon = 0.1

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import requests

# Download Tiny Shakespeare dataset
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
response = requests.get(url)
shakespeare_text = response.text

# Save Tiny Shakespeare dataset to input.txt
with open('input.txt', 'w', encoding='utf-8') as f:
    f.write(shakespeare_text)

# Larger travel dataset
travel_text = """Explore Paris, France:\nDay 1: Visit the Eiffel Tower, enjoy a Seine River cruise, and dine at a charming café in Montmartre. Don't miss the Sacré-Cœur Basilica for stunning city views.\nDay 2: Explore the Louvre Museum, home to the Mona Lisa and thousands of artworks. Stroll along the Champs-Élysées and visit the Arc de Triomphe.\nDay 3: Take a day trip to Versailles Palace, known for its opulent gardens and Hall of Mirrors. Return to Paris for a cozy dinner in Le Marais.\n\nDiscover Tokyo, Japan:\nDay 1: Experience the vibrant Shibuya Crossing and visit the Meiji Shrine. Enjoy sushi at a local restaurant in Ginza.\nDay 2: Explore Asakusa and the historic Senso-ji Temple. Take a trip to Tokyo Skytree for panoramic views of the city.\nDay 3: Visit Akihabara for electronics and anime culture, then relax in Ueno Park and explore its museums.\n\nPlan a Trip to New York City, NY, USA:\nDay 1: Start at Times Square, visit the Empire State Building, and catch a Broadway show in the evening.\nDay 2: Walk through Central Park, visit the Metropolitan Museum of Art, and take a ferry to the Statue of Liberty.\nDay 3: Explore Brooklyn, including the Brooklyn Bridge and DUMBO neighborhood. Enjoy pizza at a local pizzeria.\n\nTravel Tips for Europe:\n- Book train tickets in advance for Eurail passes to save money.\n- Pack light but include comfortable walking shoes for cobblestone streets.\n- Learn basic phrases in the local language to enhance your experience.\n- Always carry a reusable water bottle to stay hydrated.\n\nTravel Tips for Asia:\n- Respect local customs, such as removing shoes before entering temples.\n- Try street food but ensure it’s from reputable vendors to avoid foodborne illnesses.\n- Use apps like Google Translate for real-time translation in non-English-speaking countries.\n\nTop Beaches in the World:\n- Maldives: Crystal-clear waters and overwater bungalows make it a paradise.\n- Bora Bora, French Polynesia: Stunning lagoons and luxurious resorts.\n- Santorini, Greece: Black sand beaches and breathtaking caldera views.\n\nAdventure Travel in South America:\n- Hike the Inca Trail to Machu Picchu, Peru, for a bucket-list experience.\n- Explore the Amazon Rainforest in Brazil for unique wildlife encounters.\n- Visit Patagonia in Chile for glacier trekking and stunning landscapes.\n\nCruise Itineraries:\n- Mediterranean Cruise: Visit Barcelona, Rome, and Athens for a mix of history and culture.\n- Caribbean Cruise: Stop at Jamaica, the Bahamas, and Cozumel for sun and fun.\n- Alaskan Cruise: See glaciers, whales, and fjords in a pristine natural setting."""
with open('travel.txt', 'w', encoding='utf-8') as f:
    f.write(travel_text)

# Load datasets
with open('input.txt', 'r', encoding='utf-8') as f:
    generic_text = f.read()
with open('travel.txt', 'r', encoding='utf-8') as f:
    travel_text = f.read()

# Create unified vocabulary
chars = sorted(list(set(generic_text + travel_text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# Encode datasets
generic_data = torch.tensor(encode(generic_text), dtype=torch.long)
travel_data = torch.tensor(encode(travel_text), dtype=torch.long)

# Split data
n = int(0.9 * len(travel_data))
train_data = travel_data[:n]
val_data = travel_data[n:]

In [4]:
# Data loading
def get_batch(split, data_type='travel'):
    data = train_data if split == 'train' else val_data if data_type == 'travel' else generic_data
    if len(data) < block_size + 1:
        raise ValueError(f'Dataset too small for block_size={block_size}. Reduce block_size or expand dataset.')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix]).to(device)
    y = torch.stack([data[i+1:i+block_size+1] for i in ix]).to(device)
    return x, y

# Loss estimation
@torch.no_grad()
def estimate_loss(model, data_type='travel'):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            try:
                X, Y = get_batch(split, data_type)
                logits, loss = model(X, Y)
                losses[k] = loss.item()
            except ValueError:
                losses[k] = float('inf')
        out[split] = losses.mean()
    model.train()
    return out

In [5]:
# Cleanse action function
def cleanse_action(action_text):
    return action_text.strip()

# Simulated travel environment
class TravelEnvironment:
    def __init__(self):
        self.sentiment_pipe = pipeline('sentiment-analysis', model='distilbert-base-uncased-finetuned-sst-2-english', device=0 if device == 'cuda' else -1)
    
    def step(self, state, action):
        state_text = decode(state.tolist())
        action_text = cleanse_action(decode([action.item()]))
        response = state_text + action_text
        reward = self.sentiment_pipe(response)[0]['score'] if self.sentiment_pipe(response)[0]['label'] == 'POSITIVE' else -self.sentiment_pipe(response)[0]['score']
        next_state = torch.cat((state[1:], torch.tensor([action], device=device)), dim=0)
        done = len(state_text) >= block_size
        return next_state, reward, done

# Replay buffer for DQN
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return torch.stack(state), torch.stack(action), torch.tensor(reward, device=device), torch.stack(next_state), torch.tensor(done, device=device)

In [6]:
# Attention Network
class AttentionNetwork(nn.Module):
    def __init__(self, n_embd, n_source_tasks):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 128),
            nn.ReLU(),
            nn.Linear(128, n_source_tasks + 1),
            nn.Softmax(dim=-1)
        )
    
    def forward(self, x):
        return self.net(x)

# Transformer Components
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )
    
    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
    
    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [7]:
# Bigram Language Model for source tasks
class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        tok_emb = self.token_embedding_table(idx)
        logits = self.lm_head(tok_emb)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

# A2T Model
class A2TModel(nn.Module):
    def __init__(self, vocab_size, n_source_tasks):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.q_head = nn.Linear(n_embd, vocab_size)
        self.attention_network = AttentionNetwork(n_embd, n_source_tasks)
        self.source_models = nn.ModuleList([BigramLanguageModel() for _ in range(n_source_tasks)])
    
    def forward(self, idx, targets=None, mode='policy'):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        
        if mode == 'policy':
            base_logits = self.lm_head(x)
        else:
            base_logits = self.q_head(x)
        
        source_logits = [model(idx)[0] for model in self.source_models]
        attention_input = x[:, -1, :]
        weights = self.attention_network(attention_input)
        combined_logits = weights[:, -1:].unsqueeze(-1) * base_logits
        for i in range(n_source_tasks):
            combined_logits += weights[:, i:i+1].unsqueeze(-1) * source_logits[i]
        
        if targets is None:
            loss = None
        else:
            B, T, C = combined_logits.shape
            logits = combined_logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        
        return combined_logits, loss, weights

    def generate(self, idx, max_new_tokens, mode='policy'):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _, _ = self(idx_cond, mode=mode)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [8]:
# Initialize environment and replay buffer
env = TravelEnvironment()
replay_buffer = ReplayBuffer(replay_buffer_size)

# Initialize model and optimizers
model = A2TModel(vocab_size, n_source_tasks).to(device)
optimizer_policy = torch.optim.AdamW(list(model.lm_head.parameters()) + list(model.attention_network.parameters()), lr=learning_rate)
optimizer_value = torch.optim.RMSprop(list(model.q_head.parameters()) + list(model.attention_network.parameters()), lr=learning_rate)

# Pre-train source models on generic dataset
for i, source_model in enumerate(model.source_models):
    source_optimizer = torch.optim.AdamW(source_model.parameters(), lr=learning_rate)
    for iter in range(1000):
        xb, yb = get_batch('train', data_type='generic')
        logits, loss = source_model(xb, yb)
        source_optimizer.zero_grad(set_to_none=True)
        loss.backward()
        source_optimizer.step()
    print(f'Source model {i+1} trained, final loss: {loss.item():.4f}')

# Training loop for A2T with REINFORCE
def train_reinforce():
    model.train()
    for iter in tqdm(range(max_iters)):
        try:
            xb, yb = get_batch('train')
            state = xb[:, :-1]
            action = xb[:, -1]
            
            logits, loss, weights = model(state, mode='policy')
            probs = F.softmax(logits[:, -1, :], dim=-1)
            action_dist = torch.distributions.Categorical(probs)
            sampled_action = action_dist.sample()
            next_state, reward, done = env.step(state[0], sampled_action[0])
            
            log_prob = action_dist.log_prob(sampled_action)
            policy_loss = -log_prob * reward
            optimizer_policy.zero_grad()
            policy_loss.mean().backward()
            optimizer_policy.step()
            
            if iter % eval_interval == 0 or iter == max_iters - 1:
                losses = estimate_loss(model)
                print(f'step {iter}: train loss {losses["train"]:.4f}, val loss {losses["val"]:.4f}')
        except ValueError as e:
            print(f'Error at iteration {iter}: {e}')
            break

# Training loop for A2T with Actor-Critic
def train_actor_critic():
    model.train()
    critic = nn.Sequential(
        nn.Linear(n_embd, 128),
        nn.ReLU(),
        nn.Linear(128, 1)
    ).to(device)
    critic_optimizer = torch.optim.AdamW(critic.parameters(), lr=learning_rate)
    
    for iter in tqdm(range(max_iters)):
        try:
            xb, yb = get_batch('train')
            state = xb[:, :-1]
            logits, _, weights = model(state, mode='policy')
            probs = F.softmax(logits[:, -1, :], dim=-1)
            action_dist = torch.distributions.Categorical(probs)
            action = action_dist.sample()
            next_state, reward, done = env.step(state[0], action[0])
            
            state_emb = model.blocks(model.token_embedding_table(state) + model.position_embedding_table(torch.arange(state.shape[1], device=device)))
            value = critic(state_emb[:, -1, :])
            next_state_emb = model.blocks(model.token_embedding_table(next_state.unsqueeze(0)) + model.position_embedding_table(torch.arange(next_state.shape[0], device=device)))
            next_value = critic(next_state_emb[:, -1, :])
            
            delta = reward + (0.99 * next_value * (1 - done)) - value
            critic_loss = delta ** 2
            critic_optimizer.zero_grad()
            critic_loss.backward()
            critic_optimizer.step()
            
            log_prob = action_dist.log_prob(action)
            actor_loss = -log_prob * delta.detach()
            optimizer_policy.zero_grad()
            actor_loss.mean().backward()
            optimizer_policy.step()
            
            if iter % eval_interval == 0 or iter == max_iters - 1:
                losses = estimate_loss(model)
                print(f'step {iter}: train loss {losses["train"]:.4f}, val loss {losses["val"]:.4f}')
        except ValueError as e:
            print(f'Error at iteration {iter}: {e}')
            break

# Training loop for A2T with Q-learning (DQN)
def train_dqn():
    model.train()
    target_model = A2TModel(vocab_size, n_source_tasks).to(device)
    target_model.load_state_dict(model.state_dict())
    replay_buffer.clear()
    
    for iter in tqdm(range(max_iters)):
        try:
            xb, yb = get_batch('train')
            state = xb[:, :-1]
            if random.random() < epsilon:
                action = torch.randint(0, vocab_size, (batch_size,), device=device)
            else:
                q_values, _, _ = model(state, mode='value')
                action = q_values[:, -1, :].argmax(dim=-1)
            next_state, reward, done = env.step(state[0], action[0])
            replay_buffer.push(state[0], action[0], reward, next_state, done)
            
            if len(replay_buffer.buffer) >= batch_size:
                states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
                q_values, _, _ = model(states, mode='value')
                q_values = q_values.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
                with torch.no_grad():
                    next_q_values, _, _ = target_model(next_states, mode='value')
                    target_q = rewards + (1 - dones) * 0.99 * next_q_values.max(dim=-1)[0]
                loss = F.mse_loss(q_values, target_q)
                optimizer_value.zero_grad()
                loss.backward()
                optimizer_value.step()
            
            if iter % 100 == 0:
                target_model.load_state_dict(model.state_dict())
            
            if iter % eval_interval == 0 or iter == max_iters - 1:
                losses = estimate_loss(model)
                print(f'step {iter}: train loss {losses["train"]:.4f}, val loss {losses["val"]:.4f}')
        except ValueError as e:
            print(f'Error at iteration {iter}: {e}')
            break

Device set to use cpu


Source model 1 trained, final loss: 2.3680
Source model 2 trained, final loss: 2.3633


In [9]:
# Train with REINFORCE
train_reinforce()

# Generate travel itinerary
context = torch.tensor(encode('Plan a trip to Paris'), dtype=torch.long, device=device).unsqueeze(0)
generated = model.generate(context, max_new_tokens=100, mode='policy')
print(decode(generated[0].tolist()))

  0%|          | 3/2000 [00:03<34:57,  1.05s/it]  

step 0: train loss inf, val loss inf


  5%|▌         | 103/2000 [00:15<13:33,  2.33it/s]

step 100: train loss inf, val loss inf


 10%|█         | 203/2000 [00:26<12:50,  2.33it/s]

step 200: train loss inf, val loss inf


 15%|█▌        | 303/2000 [00:38<12:18,  2.30it/s]

step 300: train loss inf, val loss inf


 20%|██        | 403/2000 [00:48<10:53,  2.44it/s]

step 400: train loss inf, val loss inf


 25%|██▌       | 503/2000 [00:59<10:25,  2.39it/s]

step 500: train loss inf, val loss inf


 30%|███       | 603/2000 [01:10<09:47,  2.38it/s]

step 600: train loss inf, val loss inf


 35%|███▌      | 703/2000 [01:21<08:43,  2.48it/s]

step 700: train loss inf, val loss inf


 40%|████      | 803/2000 [01:32<08:18,  2.40it/s]

step 800: train loss inf, val loss inf


 45%|████▌     | 903/2000 [01:43<07:29,  2.44it/s]

step 900: train loss inf, val loss inf


 50%|█████     | 1003/2000 [01:54<06:55,  2.40it/s]

step 1000: train loss inf, val loss inf


 55%|█████▌    | 1103/2000 [02:05<06:08,  2.43it/s]

step 1100: train loss inf, val loss inf


 60%|██████    | 1203/2000 [02:15<05:28,  2.43it/s]

step 1200: train loss inf, val loss inf


 65%|██████▌   | 1303/2000 [02:26<04:36,  2.52it/s]

step 1300: train loss inf, val loss inf


 70%|███████   | 1403/2000 [02:37<04:00,  2.48it/s]

step 1400: train loss inf, val loss inf


 75%|███████▌  | 1503/2000 [02:48<03:33,  2.32it/s]

step 1500: train loss inf, val loss inf


 80%|████████  | 1603/2000 [02:59<02:45,  2.40it/s]

step 1600: train loss inf, val loss inf


 85%|████████▌ | 1703/2000 [03:09<01:58,  2.50it/s]

step 1700: train loss inf, val loss inf


 90%|█████████ | 1803/2000 [03:20<01:22,  2.38it/s]

step 1800: train loss inf, val loss inf


 95%|█████████▌| 1903/2000 [03:31<00:39,  2.45it/s]

step 1900: train loss inf, val loss inf


100%|██████████| 2000/2000 [03:42<00:00,  8.98it/s]

step 1999: train loss inf, val loss inf





Plan a trip to Parise verecay chexpfoyndway rixppsiaon UMef Enth Ps gas.
Éalbu, Brky Bracrthy Lit me forren theavintr Pa


## Notes
- **Dataset**: The sample `travel.txt` is ~1,200 characters. For better results, use a larger dataset (e.g., 100,000+ characters from travel guides or blogs).
- **Reward Function**: The sentiment-based reward can be replaced with a travel-specific metric (e.g., keyword matching for 'Eiffel Tower'). Example:
  ```python
  def compute_travel_reward(response):
      keywords = ['Eiffel Tower', 'Louvre', 'Paris', 'museum', 'café']
      return sum(1 for keyword in keywords if keyword.lower() in response.lower())
  ```
- **Tokenizer**: Uses character-level tokenization. For subword tokenization, integrate `GPT2Tokenizer`:
  ```python
  from transformers import GPT2Tokenizer
  tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
  encode = lambda s: tokenizer.encode(s, return_tensors='pt').squeeze().tolist()
  decode = lambda l: tokenizer.decode(l)
  vocab_size = tokenizer.vocab_size
  ```
- **Hyperparameters**: Adjust `max_iters` (e.g., to 5000) or `block_size` for larger datasets.
- **Visualization**: To inspect attention weights, add:
  ```python
  import matplotlib.pyplot as plt
  weights = model(context, mode='policy')[2]
  plt.imshow(weights.cpu().detach().numpy(), cmap='hot')
  plt.title('Attention Weights')
  plt.show()
  ```
- If errors occur, reduce `block_size` or expand `travel.txt`.