In [1]:
import csv
from PIL import Image
import os
import json
import numpy as np
import torch
import random
import torch.nn as nn
import clip
from tqdm import tqdm
class CFG():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dataset_train_per = 0.8
    lr   = 1e-4
    weight_decay = 1e-3
    patience = 1
    factor = 0.8
    epochs = 100
    batch_size = 24
    workers = 16

In [2]:

class Metaworld_Dataset:
    def __init__(self,json_file_dir,data_dir,images_transform,general_transform) -> None:
        self.csv_file = json_file_dir
        self.data_dir = data_dir
        self.images_transform = images_transform
        self.general_transform = general_transform
        with open(json_file_dir, "r") as read_file:
            data = json.load(read_file)
    


        self.data = data
        self.num_seqs = len(self.data.keys())
        self.seq_len = 200

        self.instructons = {'button-press-topdown-v2':['press the button']}
        self.max = 0


    def __len__(self):
        return (self.num_seqs * self.seq_len)
    
    def __getitem__(self,index):
        #idx = str(idx) 
        seq_num = index//self.seq_len
        idx     = (index - (seq_num*self.seq_len))

        data     = self.data[str(seq_num)]['data'] 
        taskname = self.data[str(seq_num)]['task_name']
        instruct = random.choice(self.instructons[taskname])

        ret = {}
       

        step        = data['step'][idx]
        prev_action = data['prev_action'][idx]
        action      = data['action'][idx]
        reward      = data['reward'][idx]
        state       = data['state'][idx]
    
        images_dir = os.path.join(self.data_dir,'images',taskname,str(seq_num))
        
        images_dirs =  [images_dir+'/'+str(step)+'_corner.png'        ,       
                        images_dir+'/'+str(step)+'_corner2.png'     , 
                        images_dir+'/'+str(step)+'_behindGripper.png',
                        images_dir+'/'+str(step)+'_corner3.png'     , 
                        images_dir+'/'+str(step)+'_topview.png'     
        ]
        step_images = []
        for i in range(len(images_dirs)):
            image = Image.open(images_dirs[i])
            image = self.images_transform(image)
            step_images.append(image)

        action = torch.tensor(action)
        #action[0:3] = (action[0:3]+5)/10
        action[action == -1] = 2


        ret['image']       = torch.stack(step_images)
        ret['action']      = action
        ret['state']       = torch.tensor(state)
        ret['prev_action'] = torch.tensor(prev_action)
        ret['reward']      = torch.tensor(reward)
        ret['caption']     = instruct
        
       
        return ret
    

def prepare_batch(batch):
    batch['image'] = batch['image'].permute(2,0,1,3,4,5)
    

    return batch


In [3]:
class AvgMeter:
    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        self.avg, self.sum, self.count = [0] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        text = f"{self.name}: {self.avg:.4f}"
        return text

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]


In [4]:
import math
class PositionalEncoding(nn.Module):
    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        # Modified version from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
        # max_len determines how far the position can have an effect on a token (window)
        
        # Info
        self.dropout = nn.Dropout(dropout_p)
        
        # Encoding - From formula
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1) # 0, 1, 2, 3, 4, 5
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model) # 1000^(2i/dim_model)
        
        # PE(pos, 2i) = sin(pos/1000^(2i/dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        
        # PE(pos, 2i + 1) = cos(pos/1000^(2i/dim_model))
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)
        
        # Saving buffer (same as parameter without gradients needed)
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding",pos_encoding)
        
    def forward(self, token_embedding: torch.tensor) -> torch.tensor:
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])

class Transformer(nn.Module):
    """
    Model from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1
    """
    # Constructor
    def __init__(
        self,
        dim_model,
        num_heads,
        num_encoder_layers,
        dropout_p,
        seq_length,
        emp_length,
        num_actions,
        variations_per_action
    ):
        super().__init__()

        # INFO
        self.model_type = "Transformer"
        self.dim_model = dim_model

        # LAYERS
        self.positional_encoder = PositionalEncoding(
            dim_model=dim_model, dropout_p=dropout_p, max_len=seq_length
        )
        #self.embedding = nn.Embedding(8, dim_model)
       
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim_model, nhead=num_heads,dropout=0.1)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        self.flatten = nn.Flatten(1)
        self.outs    = [nn.Linear(seq_length*emp_length, variations_per_action).to(CFG.device) for i in range(num_actions)]
        

    def forward(self, src):
        
        src = self.positional_encoder(src)
      
        src = src.permute(1,0,2)
        
        transformer_out = self.transformer(src)
        transformer_out = transformer_out.permute(1,0,2)
        transformer_out = self.flatten(transformer_out)
        
        rets = []
        for i in range(4):
            out = self.outs[i](transformer_out)
            rets.append(out)

        #rets = torch.cat(rets,1)

        return rets
    

In [5]:

class Policy(nn.Module):
    def __init__(self,language_text_model,policy_head,seq_length,emp_length):
        super().__init__()

        self.language_text_model = language_text_model
        self.policy_head = policy_head
        self.seq_length = seq_length
        self.emp_length = emp_length
        
        self.language_text_model.eval()
    def forward(self,batch):
        batch_size,cams,ch,h,w  = batch['image'].shape
        batch["image"] = torch.flatten(batch["image"], start_dim=0, end_dim=1)

        image_features = self.language_text_model.encode_image(batch["image"])
        text_features = self.language_text_model.encode_text(batch["caption"])
        
        image_features = torch.unflatten(image_features,dim = 0,sizes=(batch_size,cams))

        embeddings = torch.cat([image_features,text_features[:,None,:]],dim=1)
        embeddings = embeddings.flatten(1)
        embeddings = embeddings.unflatten(-1,(self.seq_length,self.emp_length)) # batch 192 , 8

        logits = self.policy_head(embeddings)

        return logits

In [6]:
def train_epoch(model, train_loader, optimizer, lr_scheduler, step,criterion):
    loss_meter = AvgMeter()
    tqdm_object = tqdm(train_loader, total=len(train_loader))
    for batch in tqdm_object:
        batch['caption'] = clip.tokenize(batch['caption']).to(CFG.device)
        batch['image']   = batch['image'].to(CFG.device)
        batch['action'] = batch['action'].to(CFG.device)
        logits = model(batch)
       
        loss = 0
        for i in range(4):

            loss += criterion(logits[i],batch['action'][:,i].to(torch.int64))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()

        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
    return loss_meter


def valid_epoch(model, valid_loader,criterion):
    loss_meter = AvgMeter()
    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    for batch in tqdm_object:
        batch['caption'] = clip.tokenize(batch['caption']).to(CFG.device)
        batch['image']   = batch['image'].to(CFG.device)
        batch['action'] = batch['action'].to(CFG.device)
        logits = model(batch)
       
        loss = 0
        for i in range(4):
            loss += criterion(logits[i],batch['action'][:,i].to(torch.int64))
        tqdm_object.set_postfix(valid_loss=loss.mean())
    return loss_meter

from torchvision import transforms

def main():

    policy_head = Transformer(
    dim_model=8,
    num_heads=2,
    num_encoder_layers=3,
    dropout_p=0.1,
    seq_length = 384,
    emp_length = 8,
    num_actions= 4,
    variations_per_action = 3
    ).to(CFG.device) 
    clip_model , preprocess = clip.load("ViT-B/32", device=CFG.device)
    
    policy = Policy(language_text_model=clip_model,
                    policy_head=policy_head,
                    seq_length=384,
                    emp_length=8,
                    )

    actions_transforms = transforms.Compose([transforms.ToTensor()])
    dataset = Metaworld_Dataset('/media/ahmed/HDD/WorkSpace/master_thesis/repo/datasets/metaworld/single_env.json',
                                '/media/ahmed/HDD/WorkSpace/master_thesis/repo/datasets/metaworld/',preprocess,actions_transforms)
    dataset_length = len(dataset)
    trainset_length = int(dataset_length * CFG.dataset_train_per)
    print('dataset len:',dataset_length,' trainset len:',trainset_length)
    train_set, val_valid = torch.utils.data.random_split(dataset, [trainset_length, dataset_length - trainset_length])

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=CFG.batch_size,shuffle=True, num_workers = CFG.workers)
    valid_loader = torch.utils.data.DataLoader(val_valid, batch_size=CFG.batch_size,shuffle=False,num_workers = CFG.workers)
    

    
    criterion = nn.CrossEntropyLoss()
    params = [
        {"params": policy.policy_head.parameters(), "lr": CFG.lr},
    ]
    optimizer = torch.optim.AdamW(params, weight_decay=CFG.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )
    step = "epoch"

    best_loss = float('inf')
    for epoch in range(CFG.epochs):
        print(f"Epoch: {epoch + 1}")
        policy.policy_head.train()
        train_loss = train_epoch(policy, train_loader, optimizer, lr_scheduler, step,criterion)
        policy.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(policy, valid_loader)
            if valid_loss.avg < best_loss:
                best_loss = valid_loss.avg
                torch.save(policy.policy_head.state_dict(), "best.pt")
                print("Saved Best Model!")
            
        lr_scheduler.step(valid_loss.avg)

In [7]:
main()

dataset len: 200000  trainset len: 160000
Epoch: 1


  0%|          | 24/6667 [01:09<5:20:12,  2.89s/it, lr=0.0001, train_loss=3.58]


KeyboardInterrupt: 

In [None]:
import torch
import clip
from PIL import Image
from torchvision import transforms

actions_transforms = transforms.Compose([transforms.ToTensor()])

model, preprocess = clip.load("ViT-B/32", device=CFG.device)
dataset = Metaworld_Dataset('/media/ahmed/HDD/WorkSpace/master_thesis/repo/datasets/metaworld/single_env.json','/media/ahmed/HDD/WorkSpace/master_thesis/repo/datasets/metaworld/',preprocess,actions_transforms)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=8,shuffle=True)
model.eval()
for batch in train_loader:
    batch_size,cams,ch,h,w  = batch['image'].shape
    print(batch['image'].shape)
    batch["image"] = torch.flatten(batch["image"], start_dim=0, end_dim=1)

    text = clip.tokenize(batch['caption']).to(device)

    image_features = model.encode_image(batch["image"].to(device))
    text_features = model.encode_text(text)
    
    image_features = torch.unflatten(image_features,dim = 0,sizes=(batch_size,cams))

    print(image_features.shape,text_features.shape)


torch.Size([8, 5, 3, 224, 224])


NameError: name 'device' is not defined