In [1]:
# Copyright (c) 2023, HLSS. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [2]:
import os

import torch

from predictors import QCNet
from datamodules.argoverse_v2_datamodule import ArgoverseV2DataModule

# Index of the GPU to use. torch.cuda.set_device raises if this index does not exist.
GPU_ID = 3
# Only select a device when CUDA is present, so the CPU-side data-loading cells
# can still be run for inspection. The later `.cuda()` calls still require a GPU.
if torch.cuda.is_available():
    torch.cuda.set_device(GPU_ID)
# Expand '~' up front: Python's file APIs (open, os.listdir, os.path.isdir, ...)
# treat '~' literally and do not expand it to the home directory.
DATA_ROOT = os.path.expanduser('~/dataset/argoverse/motion_forecasting')
BATCH_SIZE = 2

In [4]:
# NOTE(review): the positional arguments after DATA_ROOT look like the
# train/val/test batch sizes followed by a shuffle flag. TODO confirm against
# ArgoverseV2DataModule.__init__ and switch to keyword arguments.
data_module = ArgoverseV2DataModule(
    DATA_ROOT,
    1,
    1,
    1,
    False,
)
data_module.setup()
# Keep a handle on the training split; the batch-building cell reads from it.
dataset = data_module.train_dataset

In [5]:
# QCNet for Argoverse 2: 50 historical steps, 60 future steps, 6 modes,
# hidden size 128, 8 attention heads of dimension 16. No output head.
model = QCNet(
    dataset='argoverse_v2',
    input_dim=2,
    hidden_dim=128,
    output_dim=2,
    output_head=False,
    num_historical_steps=50,
    num_future_steps=60,
    num_modes=6,
    num_recurrent_steps=3,
    num_freq_bands=64,
    num_map_layers=1,
    num_agent_layers=2,
    num_dec_layers=2,
    num_heads=8,
    head_dim=16,
    dropout=0.1,
    pl2pl_radius=150,
    time_span=10,
    pl2a_radius=50,
    a2a_radius=50,
    num_t2m_steps=30,
    pl2m_radius=150,
    a2m_radius=150,
    lr=0.0005,
    weight_decay=0.0001,
    T_max=64,
    submission_dir='./',
    submission_file_name='submission',
)

In [6]:
# Eight processed scenario files, referenced by their processed file names.
# Judging by the name, these are presumably large scenes chosen to stress
# memory. TODO confirm how they were selected.
# The list order matters: the batch cell takes the first BATCH_SIZE entries.
large8files = [
    f'{scenario_id}.pkl'
    for scenario_id in (
        'ad7e93e9-a705-43dc-be19-8d6530e113ef',
        '721cb053-b5d7-4fda-a9c5-e5057ae218fd',
        '0c1e8c81-eb87-40e3-b8ae-f2fb7c42c9ce',
        '6ade6821-7255-4aba-b8ca-991a21fdb375',
        '75b84bd2-3090-41f6-a4fd-263bd335ef15',
        '7f308986-6ed2-493a-8862-007f642c4cee',
        'a082c53b-4cc6-45ea-8519-efbf3aa73262',
        '59e210b5-564f-43f1-a45c-1d9d5e51e5f6',
    )
]

In [7]:
from torch_geometric.data import Batch

# Collate the first BATCH_SIZE entries of `large8files` into one PyG Batch.
# Check the size first: slicing past the end of the list would silently
# produce a smaller batch than BATCH_SIZE asks for.
if BATCH_SIZE > len(large8files):
    raise ValueError(
        f'BATCH_SIZE ({BATCH_SIZE}) exceeds the number of listed files '
        f'({len(large8files)})'
    )
batch = Batch.from_data_list([
    # Find each file's position among the dataset's processed files, load it
    # with dataset.get, then apply the dataset's transform explicitly.
    dataset.transform(dataset.get(dataset.processed_file_names.index(file_name)))
    for file_name in large8files[:BATCH_SIZE]
])

In [13]:
# Run one optimisation step by hand, calling the model's training_step and
# configure_optimizers hooks directly (presumably QCNet is a Lightning module
# and no Trainer is attached here; TODO confirm).
model.cuda()
# configure_optimizers is unpacked into two iterables: optimizers and LR schedulers.
optimers, lr_schedulers = model.configure_optimizers()
loss = model.training_step(batch.cuda(), 0)
loss.backward()
for optimer in optimers:
    optimer.step()
    # Clear gradients so a re-run of this cell starts from a clean state.
    optimer.zero_grad()
# Use a distinct loop variable: iterating with `lr_schedulers` itself rebinds
# the name to the last scheduler, leaving the collection inaccessible afterwards.
for lr_scheduler in lr_schedulers:
    lr_scheduler.step()