In [3]:
#Импортируем все необходимое
import pandas as pd
import torch 
from torch import nn
import numpy as np
import random
from collections import defaultdict
from torch.utils.data import Dataset, dataloader
import gymnasium as gym
import wandb

Вспомогательные функции для SimRCRL-лосса

In [4]:
def indicator(i, j):
    """Return 0 when ``i == j`` and 1 otherwise (used to drop the j == i term)."""
    return 0 if i == j else 1
    
    
def denominator(i, batch_ah, batch_p, tau):
    """Contrastive-loss denominator for anchor ``i``.

    Accumulates exp(<ah_i, ah_j>/tau) + exp(<ah_i, p_j>/tau) over every index j
    of ``batch_ah``. The j == i term is still computed but multiplied by 0, so
    both the anchor itself and its own positive p_i are excluded from the sum.

    Args:
        i: index of the anchor row.
        batch_ah: indexable batch of 1-D anchor tensors.
        batch_p: indexable batch of 1-D positive tensors (indexed up to len(batch_ah)).
        tau: temperature.

    Returns:
        The accumulated sum (0 if ``batch_ah`` is empty).
    """
    total = 0
    for other in range(len(batch_ah)):
        anchor_vs_anchor = torch.exp(torch.dot(batch_ah[i], batch_ah[other]) / tau)
        anchor_vs_positive = torch.exp(torch.dot(batch_ah[i], batch_p[other]) / tau)
        keep = 0 if other == i else 1
        total += keep * (anchor_vs_anchor + anchor_vs_positive)
    return total


def RCRL_loss(batch_ah, batch_p, tau=.1):
    """Return-based contrastive (InfoNCE-style) loss, summed over the batch.

    For every anchor i:

        loss_i = log( sum_{j != i} [exp(<ah_i, ah_j>/tau) + exp(<ah_i, p_j>/tau)] )
                 - <ah_i, p_i>/tau

    which equals -log(exp(<ah_i, p_i>/tau) / denominator). The j == i terms
    (the anchor itself and its own positive) are excluded from the denominator.

    The computation is done in log space with ``torch.logsumexp``. The previous
    version exponentiated raw dot products divided by tau=0.1, which overflows
    float32 to inf/nan as soon as a dot product exceeds ~8.9. The math is
    unchanged; it is now also vectorized instead of using an O(n^2) Python loop.

    Args:
        batch_ah: anchors, a 2-D tensor (n, d) or a sequence of 1-D tensors.
        batch_p: positives, same layout; only the first n rows are used.
        tau: temperature (default 0.1, the previously hard-coded value).

    Returns:
        0-d tensor with the summed loss. A zero tensor for an empty batch; -inf
        for a single-sample batch (its denominator is empty), as before.
    """
    if len(batch_ah) == 0:
        return torch.zeros(())

    anchors = batch_ah if torch.is_tensor(batch_ah) else torch.stack(list(batch_ah))
    positives = batch_p if torch.is_tensor(batch_p) else torch.stack(list(batch_p))
    n = anchors.shape[0]
    positives = positives[:n]

    # Logit of each anchor against its own positive: <ah_i, p_i> / tau.
    pos_logits = (anchors * positives).sum(dim=-1) / tau

    # Pairwise logits; the diagonal (j == i) is masked out of the denominator.
    self_mask = torch.eye(n, dtype=torch.bool, device=anchors.device)
    anchor_logits = (anchors @ anchors.T / tau).masked_fill(self_mask, float("-inf"))
    positive_logits = (anchors @ positives.T / tau).masked_fill(self_mask, float("-inf"))

    log_denominator = torch.logsumexp(torch.cat([anchor_logits, positive_logits], dim=1), dim=1)
    return (log_denominator - pos_logits).sum()

    

Реализация bucket-sampling: положительные пары для контрастивного обучения эмбеддингов выбираются из одной корзины значений return

In [5]:
def get_separated_batch(batch, g_list, L):
    """Pick one positive for every sample from the samples in its return bucket.

    The value range [min(g_list), max(g_list)] is cut by ``L`` evenly spaced
    edges (``np.linspace``) into L-1 buckets. A value lying exactly on an
    interior edge goes to the lower bucket (the first bucket whose closed
    interval contains it). If all values are equal, or L < 3, there is only
    one bucket.

    Each sample is assigned to exactly one bucket, so the result has exactly
    ``len(g_list)`` entries, aligned with ``batch``. Previously, boundary
    values, or a constant ``g_list``, were added to several buckets, which
    produced extra, misaligned positives.

    Args:
        batch: indexable collection (e.g. a 2-D tensor); ``batch[j]`` is sample j.
        g_list: per-sample return values. Anything ``float()`` accepts, including
            0-d tensors that require grad (only their values are used).
        L: number of bucket edges passed to ``np.linspace``.

    Returns:
        List where entry j is a randomly chosen element (``random.choice``) of
        sample j's bucket. It may be sample j itself.
    """
    returns = np.array([float(g) for g in g_list], dtype=float)
    if returns.size == 0:
        return []

    edges = np.linspace(returns.min(), returns.max(), L)
    # Number of interior edges strictly below each value == bucket index, with
    # values on an edge falling into the lower bucket.
    bucket_ids = np.searchsorted(edges[1:-1], returns, side="left")

    members = defaultdict(list)
    for j, bucket in enumerate(bucket_ids):
        members[bucket].append(batch[j])

    return [random.choice(members[bucket]) for bucket in bucket_ids]


Класс для построения contrastive embeddings и positive/negative tokens sampling

In [6]:
class contrastive_embeddings(nn.Module):
    """Linear token embeddings for (state, action) pairs plus return-bucketed positives.

    Args:
        embedding_size: width of every token embedding.
        sequence_length: stored on the module; not used by ``forward``.
        state_dim: size of a single state vector (input of ``states_embed``).
        act_dim: size of a single action vector (input of ``actions_embed``).
        buckets_num: passed to ``get_separated_batch`` as its number of edges ``L``.
    """

    def __init__(self, embedding_size, sequence_length, state_dim, act_dim, buckets_num):
        super().__init__()

        self.sequence_length = sequence_length
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.buckets_num = buckets_num
        self.embedding_size = embedding_size

        self.states_embed = nn.Linear(self.state_dim, self.embedding_size, bias=True)
        self.actions_embed = nn.Linear(self.act_dim, self.embedding_size, bias=True)
        # NOTE(review): rewards_embed and timesteps_embed are registered (so they
        # appear in parameters()/state_dict) but forward does not use them yet.
        self.rewards_embed = nn.Linear(1, self.embedding_size, bias=True)
        self.timesteps_embed = nn.Embedding(1000, self.embedding_size)

    def _project(self, layer, x):
        """Apply ``layer`` after casting ``x`` to the layer's weight dtype (e.g. float64 -> float32)."""
        return layer(x.to(layer.weight.dtype))

    def forward(self, states, actions, rewards, return_to_go=0,
                attentions_mask=0, train=True):
        """Embed states and actions; in training mode also build anchors and positives.

        Args:
            states: tensor whose last dimension is ``state_dim``.
            actions: tensor whose last dimension is ``act_dim``.
            rewards: per-step returns used only for bucketing; flattened to 1-D.
            return_to_go, attentions_mask: accepted for interface compatibility; unused.
            train: when False only the two embeddings are returned.

        Returns:
            train=True:  (states_embeddings, actions_embeddings, z_ah_embeddings, positives)
            train=False: (states_embeddings, actions_embeddings)
        """
        states_embeddings = self._project(self.states_embed, states)
        actions_embeddings = self._project(self.actions_embed, actions)
        if not train:
            return states_embeddings, actions_embeddings

        # Anchor for each (state, action) pair. It stays in the autograd graph so
        # the contrastive loss trains both embedding layers. The old detach +
        # torch.tensor(..., requires_grad=True) created new leaf tensors, so no
        # gradient ever reached states_embed / actions_embed.
        z_ah_embeddings = states_embeddings * actions_embeddings

        # Bucketing needs only the return values. reshape(-1) replaces the
        # hard-coded reshape((10, 1)), which broke any other sequence length.
        returns = rewards.detach().reshape(-1)
        positives = torch.stack(get_separated_batch(z_ah_embeddings, returns, self.buckets_num))

        return states_embeddings, actions_embeddings, z_ah_embeddings, positives
            
    



In [7]:
class dataset(Dataset):
    """Minimal map-style Dataset over an indexable ``data`` container.

    ``__getitem__`` returns each element twice, as an ``(input, target)`` pair.
    """

    def __init__(self, data, transforms=None):
        self.data = data
        # Fixed typo (``transoforms`` raised NameError).
        # NOTE(review): stored but not applied in __getitem__ yet.
        self.transforms = transforms

    def __len__(self):
        # Use the instance's data; the bare ``data`` name was an undefined global.
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.data[idx]

Блок для самого Decision Transformer

In [8]:
class TransformerBlock(nn.Module):
    """Single self-attention layer with an (optional, default-on) causal mask.

    Inputs are first projected by ``query_emb``/``key_emb``/``value_emb`` and then
    passed to ``nn.MultiheadAttention``, which applies its own in-projections.
    No residual, layer norm or feed-forward is applied in ``forward``.

    Args:
        input_size: feature size of incoming tokens.
        hidden_size: attention embedding size, i.e. the output feature size.
        num_heads: number of attention heads (must divide ``hidden_size``).
        ff_size: output width of ``ffn``. NOTE(review): ``ffn`` is not used in forward.
        dropout: dropout probability on the attention weights.
    """

    def __init__(self, input_size, hidden_size, num_heads, ff_size, dropout=0.1):
        super().__init__()
        self.query_emb = nn.Linear(input_size, hidden_size)
        self.key_emb = nn.Linear(input_size, hidden_size)
        self.value_emb = nn.Linear(input_size, hidden_size)

        self.attention = torch.nn.MultiheadAttention(hidden_size, num_heads, dropout=dropout)
        # nn.Linear's third positional argument is ``bias``, and ``dropout`` was
        # being passed there. Bias stays enabled, so the parameter set is unchanged.
        self.ffn = nn.Linear(hidden_size, ff_size)

    def forward(self, x, causal=True):
        """Self-attend over ``x``.

        Args:
            x: (L, E) or (L, N, E) tensor. nn.MultiheadAttention is used with
                batch_first=False, so dim 0 is the sequence.
            causal: if True (default), position t may only attend to positions <= t,
                as a GPT-style Decision Transformer requires. Without it,
                state tokens could see the actions they are meant to predict.

        Returns:
            (attn_output, attn_weights) as returned by nn.MultiheadAttention.
        """
        query = self.query_emb(x)
        key = self.key_emb(x)
        value = self.value_emb(x)

        mask = None
        if causal:
            seq_len = x.shape[0]
            # True above the diagonal = "not allowed to attend" (future positions).
            mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)

        return self.attention(query, key, value, attn_mask=mask)

In [9]:
class GPT_2(nn.Module):
    """Stack of independent ``TransformerBlock`` layers applied one after another.

    Each block's attention output is added back to its input (residual
    connection), so ``hidden_size`` must equal ``input_size``. The previous
    summation of layer outputs into an x-shaped tensor imposed the same
    constraint.

    Args:
        num_blocks: number of stacked blocks.
        input_size, hidden_size, num_heads, ff_size, dropout: forwarded to every block.
    """

    def __init__(self, num_blocks, input_size, hidden_size, num_heads, ff_size, dropout=0.1):
        super().__init__()

        def make_block():
            return TransformerBlock(input_size, hidden_size, num_heads, ff_size, dropout=dropout)

        # One fresh block per layer. Previously a single instance was repeated,
        # so every "layer" shared the same weights. The ``dropout`` argument was
        # also ignored (0.1 was hard-coded).
        # ``single_layer`` is kept as an attribute and doubles as layers[0].
        self.single_layer = make_block()
        layers = [self.single_layer] + [make_block() for _ in range(num_blocks - 1)]
        self.layers = nn.ModuleList(layers[:num_blocks])
        # Not called in forward (a Sequential wrapping a ModuleList is not
        # callable). Kept only so the module's state_dict keys stay unchanged.
        self.model = nn.Sequential(self.layers)

    def forward(self, x):
        """Run ``x`` through all blocks in order.

        Previously every block received the same ``x`` and the outputs were summed
        into a CPU ``torch.zeros``. With shared weights that was just
        num_blocks * block(x), not a stacked transformer.

        Args:
            x: token sequence in the layout TransformerBlock expects
                (presumably (L, E) or (L, N, E), sequence first).

        Returns:
            Tensor with the same shape, dtype and device as ``x``.
        """
        hidden = x
        for layer in self.layers:
            # layer(...) returns (attn_output, attn_weights); keep only the output.
            hidden = hidden + layer(hidden)[0]
        return hidden
        

In [10]:
class ContrastiveDT(nn.Module):
    """Decision Transformer on interleaved state/action tokens with contrastive embeddings.

    ``embedding_layer`` embeds states and actions. The tokens are interleaved as
    s_0, a_0, s_1, a_1, ... and fed to ``gpt_layer``. The flattened GPT output
    is then decoded by three linear heads into per-step state, action and
    reward predictions.

    NOTE(review): ``input_size`` is stored but the GPT stack is built with
    ``embedding_size`` as its input size. The GPT output must have
    ``embedding_size`` features for the reshape in forward to succeed.
    """

    def __init__(self, embedding_size, sequence_length, state_dim, act_dim, buckets_num,
                 num_blocks, input_size, hidden_size, num_heads, ff_size, dropout=0.1):
        super().__init__()
        self.embedding_size = embedding_size
        self.sequence_length = sequence_length
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.buckets_num = buckets_num
        self.num_blocks = num_blocks
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.ff_size = ff_size
        self.dropout = dropout

        self.embedding_layer = contrastive_embeddings(self.embedding_size, self.sequence_length,
                                                      self.state_dim, self.act_dim, self.buckets_num)

        self.gpt_layer = GPT_2(self.num_blocks, self.embedding_size, self.hidden_size,
                               self.num_heads, self.ff_size, self.dropout)

        # The heads read the whole flattened token sequence: 2 tokens per step.
        dim = 2 * self.embedding_size * self.sequence_length
        self.output_states = nn.Linear(dim, self.sequence_length * self.state_dim)
        self.output_actions = nn.Linear(dim, self.sequence_length * self.act_dim)
        self.output_rewards = nn.Linear(dim, self.sequence_length)

    def forward(self, states, actions, rewards, train=True):
        """Predict states, actions and rewards for one sequence.

        Args:
            states, actions, rewards: one sequence of ``sequence_length`` steps,
                passed straight to ``embedding_layer``.
            train: forwarded to ``embedding_layer``. When False, the contrastive
                anchors/positives are not built and are returned as None.

        Returns:
            (out, out_states, out_actions, out_rewards, z_ah_embeddings, positives)
            where ``out`` is a list of per-step (state, action, reward) prediction
            triples, and the three prediction tensors have shapes
            (sequence_length, state_dim), (sequence_length, act_dim) and (sequence_length,).
        """
        embeddings = self.embedding_layer(states, actions, rewards, train=train)
        states_embeddings, action_embeddings = embeddings[0], embeddings[1]

        # Interleave tokens as s_0, a_0, s_1, a_1, ... while staying in autograd.
        # The previous detach().numpy() round-trip cut every gradient from the
        # prediction losses to the embedding layer.
        gpt_input_embeddings = torch.stack((states_embeddings, action_embeddings), dim=1)
        gpt_input_embeddings = gpt_input_embeddings.reshape(-1, states_embeddings.shape[-1])

        out = self.gpt_layer(gpt_input_embeddings)
        out = out.reshape((2 * self.embedding_size * self.sequence_length,))

        out_states = self.output_states(out).reshape((self.sequence_length, self.state_dim))
        out_actions = self.output_actions(out).reshape((self.sequence_length, self.act_dim))
        out_rewards = self.output_rewards(out).reshape((self.sequence_length))

        # Per-step (state, action, reward) prediction triples.
        out = list(zip(out_states, out_actions, out_rewards))

        if train:
            z_ah_embeddings, positives = embeddings[2], embeddings[3]
        else:
            z_ah_embeddings, positives = None, None

        return out, out_states, out_actions, out_rewards, z_ah_embeddings, positives
        
        
        
        
        
        
        

In [12]:
# Instantiate the contrastive Decision Transformer with the experiment's hyper-parameters.
model = ContrastiveDT(
    embedding_size=128,
    sequence_length=10,
    state_dim=11,
    act_dim=3,
    buckets_num=5,
    num_blocks=12,
    input_size=128,
    hidden_size=128,
    num_heads=8,
    ff_size=128,
    dropout=0.1,
)

Блок подготовки данных для экспериментов

In [30]:
def prepare_sample(path, seed_value, action_dim=0):
    """Collect one (state, action, reward) sample from a freshly created Gymnasium env.

    The env is created with ``gym.make(path)`` and reset with ``seed_value``.
    One random action is sampled and a single step is taken; the env is always
    closed, even if a call raises.

    Args:
        path: Gymnasium environment id passed to ``gym.make``.
        seed_value: seed for ``env.reset`` and for the action-space sampler. The
            latter is seeded separately, so sampled actions are reproducible.
        action_dim: length of the one-hot action vector. If falsy (the default 0,
            which previously crashed on the one-hot write), it is taken from
            ``env.action_space.n``.

    Returns:
        (observation, action, reward):
        - observation: the flattened observation the action was taken in. Previously
          the post-step (or post-reset) observation was returned, which did not
          correspond to the returned action/reward.
        - action: one-hot torch tensor.
        - reward: the step reward.

    Assumes a discrete action space (``action_space.sample()`` returns an index).
    """
    env = gym.make(path)
    try:
        observation, info = env.reset(seed=seed_value)
        env.action_space.seed(seed_value)

        action_num = env.action_space.sample()  # random-policy placeholder
        _, reward, terminated, truncated, info = env.step(action_num)

        if not action_dim:
            action_dim = env.action_space.n
    finally:
        env.close()

    action = torch.zeros((action_dim))
    action[action_num] = 1

    return observation.flatten(), action, reward
    
    
    

In [60]:
def collect_dataset(path, length, action_dim=0):
    """Build stacked (states, actions, rewards) tensors from ``length`` single-step samples.

    Sample ``i`` comes from ``prepare_sample(path, i, action_dim)``. Each sample
    therefore uses a freshly created env seeded with its index, so the samples
    are independent rather than one trajectory.

    Args:
        path: Gymnasium environment id.
        length: number of samples (must be >= 1; torch.stack rejects an empty list).
        action_dim: forwarded to ``prepare_sample``.

    Returns:
        (states, actions, rewards): states as a float64 tensor with one row per
        sample; actions and rewards stacked along a new first dimension.

    The returned tensors are plain data and do not require grad. Previously
    ``requires_grad=True`` made them autograd graph nodes. Reward tensors that
    require grad also cannot be converted to numpy (e.g. by ``np.linspace``).
    """
    states_data = []
    actions_data = []
    rewards_data = []
    for i in range(length):
        observation, action, reward = prepare_sample(path, i, action_dim)
        states_data.append(torch.tensor(observation, dtype=torch.float64))
        # ``action`` is already a tensor. Clone it instead of torch.tensor(tensor),
        # which copies anyway and emits a UserWarning.
        actions_data.append(torch.as_tensor(action).clone())
        rewards_data.append(torch.tensor(reward))

    return torch.stack(states_data), torch.stack(actions_data), torch.stack(rewards_data)
        
        
        