In [None]:
from typing import List

import numpy as np
import pandas as pd
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torchvision import models
from torchvision.transforms import transforms
from dataclasses import dataclass
from tqdm import tqdm

In [None]:
@dataclass
class State:
    """Snapshot of one environment state sampled from the frame inventory."""
    frame: torch.Tensor  # image tensor; find_state_with_stone stores the normalised ToTensorV2 output here
    sp: float       # stamina points
    zoom: int       # zoom level, taken from the inventory's "zoom" column

@dataclass
class Action:
    """One agent action.

    NOTE(review): this class is not used in the visible cells; the field
    meanings below are inferred from their names — confirm.
    """
    direction: int    # presumably a movement direction index (NeuralEnv.step takes 1..8) — confirm
    time_steps: int   # presumably how many time steps the direction is held — confirm

# Check if frame has stone
def check_frame(self, frame, crop_size: int = 24, threshold: float = 0.5) -> bool:
    """Return True if the stone classifier detects a stone in any tile of ``frame``.

    The frame is cut into a grid of non-overlapping ``crop_size`` x ``crop_size``
    tiles (4 x 4 tiles of 24 px for a 96 x 96 frame). All tiles are scored by
    ``self.stone_classifier`` in one batch. The highest class-1 ("stone")
    probability over the tiles is compared with ``threshold``.

    Args:
        self: any object exposing a ``stone_classifier`` callable that maps a
            (N, C, crop_size, crop_size) batch to (N, 2) probabilities, e.g. a
            ``NeuralEnv`` instance.
        frame: image tensor shaped (C, H, W), e.g. the output of ToTensorV2.
        crop_size: tile edge length; must match the classifier's input size.
        threshold: minimum stone probability for a positive answer.

    Returns:
        bool: whether any tile scores above ``threshold``.

    Raises:
        ValueError: if the frame is smaller than one tile.
    """
    height, width = frame.shape[-2], frame.shape[-1]
    # Plain slicing replaces torchvision.transforms.functional.crop (the bare
    # ``torchvision`` name is not imported in this notebook). For in-bounds
    # tiles the two produce the same tensor.
    crops = [
        frame[..., top:top + crop_size, left:left + crop_size][None]
        for top in range(0, height - crop_size + 1, crop_size)
        for left in range(0, width - crop_size + 1, crop_size)
    ]
    if not crops:
        raise ValueError(
            f"frame of shape {tuple(frame.shape)} is smaller than crop_size={crop_size}")
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        probs = self.stone_classifier(torch.cat(crops))
    expected_stone = probs[:, 1].max().item()
    return expected_stone > threshold

def find_state_with_stone(df: pd.DataFrame, max_attempt: int = 10, env=None) -> State:
    """Sample random rows of ``df`` until one whose frame contains a stone is found.

    Args:
        df: frame inventory with at least ``video``, ``frame``, ``sp`` and
            ``zoom`` columns (the columns read below).
        max_attempt: maximum number of rows to try before giving up.
        env: object passed to ``check_frame`` as its ``self``; it needs a
            ``stone_classifier``. Defaults to the notebook-global ``env``.

    Returns:
        State built from the first sampled row that passes ``check_frame``.

    Raises:
        ValueError: if ``df`` is empty.
        RuntimeError: if no stone frame is found within ``max_attempt`` tries.
    """
    if len(df) == 0:
        raise ValueError("df is empty")
    if env is None:
        # Preserve the previous implicit dependency on the global ``env``.
        env = globals()["env"]

    # Build the normalisation pipeline once, not on every attempt.
    train_aug = A.Compose([A.Normalize(mean=(0.5,), std=(0.5,)),
                           ToTensorV2(transpose_mask=False),
                           ])
    for _ in range(max_attempt):
        # randint's upper bound is exclusive: len(df) makes every row reachable
        # (the old len(df) - 1 never sampled the last row).
        index = np.random.randint(len(df))
        frame = load_image(df["video"].iloc[index], df["frame"].iloc[index])
        frame = train_aug(image=frame)['image']
        # NeuralEnv has no check_frame method; use the module-level helper.
        if check_frame(env, frame):
            return State(frame, df["sp"].iloc[index], df["zoom"].iloc[index])
    raise RuntimeError(f"no frame with a stone found in {max_attempt} attempts")

def see_plot(pict, size=(6, 6), title: str = None):
    """Display ``pict`` in a new matplotlib figure using the gray colormap.

    Args:
        pict: image-like array accepted by ``imshow``.
        size: figure size in inches.
        title: optional figure title; nothing is set when None.

    NOTE(review): relies on a global ``plt`` (matplotlib.pyplot), which is not
    imported in the cells shown here — confirm it is imported elsewhere.
    """
    wants_title = title is not None
    plt.figure(figsize=size)
    plt.imshow(pict, cmap='gray')
    if wants_title:
        plt.title(title)
    plt.show()

def load_image(video, frame, root: str = '../surviv_rl_data/all_videoframes_rgb_96'):
    """Load one extracted video frame from disk as an RGB uint8 array.

    Args:
        video: video identifier, used as the sub-directory name under ``root``.
        frame: frame identifier; the file read is ``f_{frame}.jpg``.
        root: directory holding one sub-directory per video.

    Returns:
        np.ndarray shaped (H, W, 3) in RGB channel order.

    Raises:
        FileNotFoundError: if the image is missing or cannot be decoded
            (``cv2.imread`` signals both by returning None).
    """
    path = '{}/{}/f_{}.jpg'.format(root, video, frame)
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"could not read image {path!r}")
    # OpenCV decodes to BGR. cvtColor returns a contiguous RGB copy, whereas a
    # negative-stride [:, :, ::-1] view is rejected by torch.from_numpy.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

In [None]:
def convrelu(in_channels, out_channels, kernel, padding):
    """Build a Conv2d -> BatchNorm2d -> ReLU(inplace) stack.

    The layers sit at Sequential indices 0, 1 and 2, so their state_dict keys
    are ``0.*`` (conv), ``1.*`` (batch norm) and nothing for the ReLU.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


class ResNetUNet_v2(nn.Module):
    """U-Net with a ResNet-18 encoder whose bottleneck is conditioned on a vector.

    ``forward`` takes a pair ``inp = (image, cond)``:
      * ``image``: (N, 3, H, W). H and W should be divisible by 32 so that each
        2x upsampling in the decoder matches the size of its encoder skip tensor.
      * ``cond``: (N, 11). A 1x1 conv lifts it to 512 channels, and the result
        is added (broadcast over space) to the deepest encoder feature map.
        NeuralEnv.step builds it as 8 one-hot action + 2 support + 1 step value.
    The output is (N, n_class, H, W), squashed to [-1, 1] by a final tanh.

    Args:
        n_class: number of output channels.
        pretrained: initialise the encoder from torchvision's ImageNet ResNet-18
            weights (they must be downloadable or cached). Pass False when a full
            checkpoint is loaded right after construction anyway.
    """

    def __init__(self, n_class, pretrained: bool = True):
        super().__init__()

        # Only fetch ImageNet weights when asked for; otherwise skip the download.
        if pretrained:
            self.base_model = models.resnet18(pretrained=True)
        else:
            self.base_model = models.resnet18()

        # base_model stays registered as a submodule, so its parameters also
        # appear in state_dict() under "base_model.*". Saved checkpoints contain
        # those keys, so do not remove it.
        self.base_layers = list(self.base_model.children())

        self.layer0 = nn.Sequential(*self.base_layers[:3])  # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5])  # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        # Full-resolution branch that feeds the last skip connection.
        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)
        self.dropout = nn.Dropout(0.5)
        self.conv_last = nn.Conv2d(64, n_class, 1)
        self.act_last = nn.Tanh()
        self.support_conv1 = nn.Conv2d(11, 512, 1)  # (batch, 8+2+1, 1, 1) --> (batch, 512, 1, 1)

    def forward(self, inp):
        """Predict an (N, n_class, H, W) map from ``inp = (image, cond)``."""
        x_original = self.conv_original_size0(inp[0])
        x_original = self.conv_original_size1(x_original)

        layer0 = self.layer0(inp[0])
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        # (N, 11) --> (N, 11, 1, 1) --> (N, 512, 1, 1), broadcast-added over H/32 x W/32.
        cond = self.support_conv1(inp[1][:, :, None, None])
        layer4 = self.layer4_1x1(layer4 + cond)

        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        x = self.dropout(x)
        out = self.conv_last(x)
        out = self.act_last(out)

        return out
#====================================================================    
    
class StoneClassifier(nn.Module):
    """Small CNN that scores a 24x24 RGB crop as (no stone, stone).

    Three stride-2 convolutions shrink 24 -> 12 -> 6 -> 3, hence the
    32 * 3 * 3 input width of ``fc1``. ``forward`` returns softmax
    probabilities of shape (N, 2); callers in this notebook read column 1
    as the stone score.
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(32 * 3 * 3, 128)
        self.fc3 = nn.Linear(128, 2)

    def forward(self, x):
        """Return (N, 2) class probabilities for an (N, 3, 24, 24) batch."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        hidden = F.relu(self.fc1(features.flatten(start_dim=1)))
        return F.softmax(self.fc3(hidden), dim=1)

class NeuralEnv:
    def __init__(self,
                 env_model_path: str,
                 reward_model_path: str,
                 device: str,
                 batch_size = 16,
                 reward_confidence=0.5,
                 stone_frac=0.0,
                 step_size=4,
                 max_step=14,
                 dataset_csv: str = '../surviv_rl_data/dataset_inventory_v2.csv'):
        '''
        Batched environment: a learned model predicts the next frame, and a
        stone classifier turns frames into rewards.

        input params:
            env_model_path    [str] : path to model s_next=model(s_curr,action)
            reward_model_path [str] : path to model reward=model(s_curr)
            device            [str] : one of {'cpu', 'cuda:0', 'cuda:1'}
            batch_size        [int] : number of initial states returned by reset()
            reward_confidence [flt] : stone probability above which the reward is 1
            stone_frac        [flt] : intended part of the initial states with guaranteed stones
                                      NOTE(review): stored but not used by reset() yet
            step_size         [int] : time steps per call of step(); step_size/max_step is
                                      passed to the env model as a conditioning value
            max_step          [int] : normaliser for step_size (presumably the largest step
                                      the env model was trained with — confirm)
            dataset_csv       [str] : frame inventory sampled by reset(); only zoom == 1 rows are kept
        output params:
            all output-variables will be torch.tensors in the selected DEVICE
            all input-variables have to be torch.tensors in the selected DEVICE
        '''
        self.device = device
        self.batch_size = batch_size
        self.reward_confidence = reward_confidence
        self.stone_frac = stone_frac
        self.step_size = step_size
        self.max_step = max_step
        self.reward_frame_transform = transforms.Compose([transforms.CenterCrop(24)])
        # Maps uint8 RGB (H, W, 3) to float (3, H, W) in [-1, 1].
        self.frame_transform = A.Compose([A.Normalize(mean=(0.5,), std=(0.5,)),
                                          ToTensorV2(transpose_mask=False)])

        self.model = ResNetUNet_v2(3)
        self.model.load_state_dict(torch.load(env_model_path, map_location=self.device))
        self.model = self.model.to(self.device)
        self.model.eval()

        self.stone_classifier = StoneClassifier()
        # map_location lets a checkpoint saved on GPU load on a CPU-only host,
        # the same way the env model is loaded above.
        self.stone_classifier.load_state_dict(torch.load(reward_model_path, map_location=self.device))
        self.stone_classifier = self.stone_classifier.to(self.device)
        self.stone_classifier.eval()

        self.df = pd.read_csv(dataset_csv)
        self.df = self.df[self.df.zoom == 1].reset_index()
    #----------------------------------------------------------------------------------------------------

    def reset(self):
        '''
        Sample batch_size random rows of the inventory as initial states.

        output params:
            init_s     [float torch tensor [-1...1]] : batch of initial states (batch,3,96,96)
            init_supp  [float torch tensor]          : batch of initial support vector (batch,2)
                                                       = (sp/100, zoom/15) per row
        '''
        init_s = torch.zeros(self.batch_size,3,96,96).float()
        init_supp = torch.zeros(self.batch_size,2).float()

        for i in range(self.batch_size):
            j = np.random.randint(len(self.df))
            frame = load_image(self.df["video"][j], self.df["frame"][j])
            init_s[i] = self.frame_transform(image=frame)['image']
            init_supp[i] = torch.tensor([self.df["sp"][j]/100, self.df["zoom"][j]/15]).float()
        return init_s.to(self.device), init_supp.to(self.device)
    #----------------------------------------------------------------------------------------------------

    def get_reward(self, state):
        '''
        Binary reward from the stone classifier applied to the centre crop.

        input params:
            state [float torch.tensor [-1...1]] : batch of states (batch,3,96,96)
        output params:
            r      [float torch.tensor {0,1}]    : batch of rewards (batch,1)
        '''
        state = self.reward_frame_transform(state)
        with torch.no_grad():
            r = self.stone_classifier(state)[:,1].unsqueeze(1)
        r = (r>self.reward_confidence).float().detach()
        return r
    #----------------------------------------------------------------------------------------------------

    def step(self, s_curr, supp_curr, action):
        '''
        input params:
            s_curr    [float torch.tensor [-1...1]] : batch of current states (batch,3,96,96)
            supp_curr [float torch tensor]          : batch of current support vector (batch,2)
            action    [int torch tensor {1,...,8}]  : batch of chosen direction (batch,1) or (batch,)
        output params:
            s_next    [float torch.tensor [-1...1]] : batch of next states (batch,3,96,96)
            supp_next [float torch tensor]          : batch of next support vector =supp_curr (batch,2)
            reward    [float torch.tensor {0,1}]    : batch of rewards (batch,1)
        '''
        # reshape(-1) yields a 1-D index vector for (batch,1), (batch,) and (1,1)
        # actions alike, so one_hot always returns (batch, 8).
        action_ohe = F.one_hot(action.reshape(-1).long() - 1, num_classes=8).float()  # (batch,8)
        # Size the step-fraction column from the actual batch rather than
        # self.batch_size, so step() works for any batch size.
        n = torch.full((s_curr.shape[0], 1), self.step_size / self.max_step,
                       dtype=torch.float32, device=self.device)  # (batch,1)
        v = torch.cat([action_ohe, supp_curr, n], dim=1)  # (batch,8+2+1)
        with torch.no_grad():
            s_next = self.model((s_curr, v)).detach()
        reward = self.get_reward(s_next)
        return s_next, supp_curr, reward

In [None]:
class CEMModel(nn.Module):
    """Cross-entropy-method policy: a small CNN mapping a frame to 8 action probabilities.

    Args:
        threshold: percentile (0-100) of session rewards that a session must
            reach to count as "elite" in ``train_step``.
    """

    def __init__(self, threshold: int = 70):
        super().__init__()

        self.threshold = threshold

        # Three unpadded stride-2 3x3 convs take a 96x96 frame to 11x11,
        # hence 64 * 11 * 11 = 7744 inputs to linear1.
        self.conv1 = nn.Conv2d(3, 16, 3, 2)
        self.conv2 = nn.Conv2d(16, 32, 3, 2)
        self.conv3 = nn.Conv2d(32, 64, 3, 2)
        self.linear1 = nn.Linear(7744, 256)
        self.linear2 = nn.Linear(256, 8)

    def logits(self, x):
        """Return unnormalised action scores of shape (N, 8) for an (N, 3, 96, 96) batch."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.linear1(torch.flatten(x, 1)))
        return self.linear2(x)

    def forward(self, x):
        """Return action probabilities (softmax over 8 actions), shape (N, 8)."""
        return F.softmax(self.logits(x), dim=1)

    def train_step(self, states: List[List[torch.Tensor]], actions: List[List[int]], rewards: np.ndarray, batch_size=32, device=torch.device("cpu")):
        """Run one supervised epoch on the state/action pairs of elite sessions.

        A session is elite when its total reward is at least the
        ``self.threshold``-th percentile of ``rewards``.

        Args:
            states: per session, the list of state tensors (3, 96, 96).
            actions: per session, the action index (0..7) taken in each state.
            rewards: total reward per session.
            batch_size: mini-batch size.
            device: device the mini-batches are moved to.

        Returns:
            float: mean cross-entropy over the processed batches, or NaN if
            there was nothing to train on.

        Raises:
            ValueError: if the elite sessions have different numbers of states and actions.
        """
        rewards = np.asarray(rewards, dtype=np.float64)
        if rewards.size == 0:
            return float("nan")
        threshold = np.percentile(rewards, self.threshold)
        # '>=' keeps the best sessions when rewards tie. With sparse 0/1
        # rewards a strict '>' often selected nothing, and torch.cat([]) crashed.
        elite_session_indices = np.flatnonzero(rewards >= threshold)

        elite_states = [state for index in elite_session_indices for state in states[index]]
        elite_actions = [action for index in elite_session_indices for action in actions[index]]
        if not elite_states:
            return float("nan")
        if len(elite_states) != len(elite_actions):
            raise ValueError(
                f"{len(elite_states)} elite states but {len(elite_actions)} elite actions")

        dataset = TensorDataset(torch.stack(elite_states),
                                torch.as_tensor(elite_actions, dtype=torch.long))
        # Shuffle so that mini-batches are not runs of consecutive, correlated frames.
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        self.train()
        optimizer = torch.optim.Adam(self.parameters(), lr=3e-5, weight_decay=1e-5)
        loss_fn = nn.CrossEntropyLoss()
        losses = []
        for x, y_true in dataloader:
            optimizer.zero_grad()
            # CrossEntropyLoss applies log-softmax itself, so it must get logits;
            # feeding forward()'s probabilities would apply softmax twice.
            loss = loss_fn(self.logits(x.to(device)), y_true.to(device))
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        return float(np.mean(losses))

In [None]:
device = torch.device("cpu")
# A single-state environment: the training loop runs one session at a time.
env = NeuralEnv(
    env_model_path="../best_models/resunet_v5.pth",
    reward_model_path="../best_models/nostone_stone_classifier.pth",
    device="cpu",
    batch_size=1,
)

In [None]:
agent = CEMModel().to(device)

epoches = 10
sessions_per_epoch = 10
steps_per_session = 10
for epoch in range(epoches):
    states = []
    actions = []
    rewards = np.zeros(sessions_per_epoch)
    for i in tqdm(range(sessions_per_epoch)):
        session_states = []
        session_actions = []

        s_curr, supp_curr = env.reset()
        for _ in range(steps_per_session):
            # Sampling an action needs no autograd graph.
            with torch.no_grad():
                probs = agent(s_curr.to(device))[0].cpu().numpy().astype(np.float64)
            # A float32 softmax can miss sum == 1 by more than np.random.choice tolerates.
            probs /= probs.sum()
            chosen_action = np.random.choice(len(probs), p=probs)
            s_next, supp_next, reward = env.step(
                s_curr, supp_curr, torch.LongTensor([[chosen_action + 1]]).to(device))

            # CEM learns state -> action, so store the state in which the action
            # was chosen (the old code stored s_next, i.e. the action's outcome).
            session_states.append(s_curr[0])
            session_actions.append(chosen_action)
            # reward is (1, 1); .item() yields a scalar for the scalar slot.
            rewards[i] += reward.sum().item()

            s_curr, supp_curr = s_next, supp_next
        states.append(session_states)
        actions.append(session_actions)
    agent.train_step(states, actions, rewards, device=device)
    print(f"{epoch}: {rewards.mean():.2f}")
    torch.save(agent.state_dict(), "../weights/agent.pth")