In [None]:
import os
import pandas as pd
import numpy as np

In [None]:
# Indicator column names as listed for the raw sheet, including both suffixed
# (`_x` / `_y`) variants of the exports column.
data_v = [
    "Rural population growth (annual %)",
    "General government final consumption expenditure (annual % growth)",
    "Consumer price index (2010 = 100)",
    "Exports of goods and services (annual % growth)_x",
    "Urban population growth (annual %)",
    "Population growth (annual %)",
    "Inflation, GDP deflator (annual %)",
    "Imports of goods and services (annual % growth)",
    "Final consumption expenditure (annual % growth)",
    "Unemployment, total (% of total labor force) (national estimate)",
    "Exports of goods and services (annual % growth)_y",
    "Inflation, consumer prices (annual %)",
    "Gross fixed capital formation (annual % growth)",
    "Households and NPISHs Final consumption expenditure (annual % growth)",
]

In [None]:
# Sentence appended to every indicator description in `get_prompt`;
# the `{value}` placeholder is filled in there via `str.format`.
prompt_str = 'In this year, the index is {value}.'

In [None]:
# Maps each indicator column name to a one-sentence natural-language description.
# `get_prompt` looks the description up by column name and appends `prompt_str` to it,
# so every value deliberately ends with ". " (period + space).
prompt_dict = {
    "Rural population growth (annual %)": "The index 'Rural population growth (annual %)' measures the annual percentage increase or decrease in the rural population of a country. ", 
    "General government final consumption expenditure (annual % growth)": "The index 'General government final consumption expenditure (annual % growth)' measures the annual percentage increase or decrease in government spending on goods and services that are used for providing public services. ", 
    "Consumer price index (2010 = 100)": "The index 'Consumer price index (2010 = 100)' measures the average change over time in the prices paid by consumers for a basket of goods and services, with the base year set to 2010. This indicator is essential for tracking inflation, assessing the cost of living, and guiding monetary policy decisions. ",
    "Exports of goods and services (annual % growth)": "The index 'Exports of goods and services (annual % growth)' measures the annual percentage increase or decrease in the value of a country's exports of goods and services. ",
    "Urban population growth (annual %)": "The index 'Urban population growth (annual %)' measures the annual percentage increase or decrease in the population residing in urban areas. ", 
    "Population growth (annual %)": "The index 'Population growth (annual %)' measures the annual percentage increase or decrease in the total population of a country. ",
    "Inflation, GDP deflator (annual %)": "The index 'Inflation, GDP deflator (annual %)' measures the annual percentage change in the price level of all new, domestically produced, final goods and services in an economy. ",
    "Imports of goods and services (annual % growth)": "The index 'Imports of goods and services (annual % growth)' measures the annual percentage increase or decrease in the value of a country's imports of goods and services. ", 
    "Final consumption expenditure (annual % growth)": "The index 'Final consumption expenditure (annual % growth)' measures the annual percentage change in the total value of all goods and services consumed by households and government. ", 
    "Unemployment, total (% of total labor force) (national estimate)": "The index 'Unemployment, total (% of total labor force) (national estimate)' measures the percentage of the total labor force that is unemployed and actively seeking employment, based on national estimates. ",
    "Inflation, consumer prices (annual %)": "The index 'Inflation, consumer prices (annual %)' measures the annual percentage change in the average level of prices for consumer goods and services. ", 
    "Gross fixed capital formation (annual % growth)": "The index 'Gross fixed capital formation (annual % growth)' measures the annual percentage increase or decrease in investment in fixed assets such as buildings, machinery, and infrastructure. ",
    "Households and NPISHs Final consumption expenditure (annual % growth)": "The index 'Households and NPISHs Final consumption expenditure (annual % growth)' measures the annual percentage change in the spending by households and Non-Profit Institutions Serving Households (NPISHs) on goods and services. ", 

}

In [None]:
# Number of indicator descriptions available.
len(prompt_dict)

In [None]:
# Load the yearly indicator table; the first spreadsheet column becomes the index.
df_mlp_data = pd.read_excel('integrated_yearly_data.xlsx', index_col=0)
# Show the available column names.
df_mlp_data.columns

In [None]:
# Keep the identifier columns, the 13 indicators and the GDP-growth target.
# The raw sheet names the exports indicator with a `_x` suffix; it is renamed *by name*
# (instead of overwriting `.columns` with a second, positionally-matched list) so a
# re-ordered or edited selection can never silently mislabel a column.
# NOTE: this cell overwrites `df_mlp_data`, so it is not re-runnable without re-loading the sheet.
df_mlp_data = df_mlp_data[[
    'Country Name', 'Country Code', 'year',
    'Rural population growth (annual %)',
    'General government final consumption expenditure (annual % growth)',
    'Consumer price index (2010 = 100)',
    'Exports of goods and services (annual % growth)_x',
    'Urban population growth (annual %)',
    'Population growth (annual %)',
    'Inflation, GDP deflator (annual %)',
    'Imports of goods and services (annual % growth)',
    'Final consumption expenditure (annual % growth)',
    'Unemployment, total (% of total labor force) (national estimate)',
    'Inflation, consumer prices (annual %)',
    'Gross fixed capital formation (annual % growth)',
    'Households and NPISHs Final consumption expenditure (annual % growth)',
    'GDP growth (annual %)',
]].rename(columns={
    'Exports of goods and services (annual % growth)_x':
        'Exports of goods and services (annual % growth)',
})


In [None]:
# Check the column labels after selection / renaming.
df_mlp_data.columns

In [None]:
# Number of columns kept.
df_mlp_data.shape[1]

In [None]:
# Preview the first rows of the selected table.
df_mlp_data.head()

In [None]:
# Indicators that are turned into prompts. The ORDER matters: `run` appends each
# indicator's position in this list to its feature vector.
select_v_list = [
    "Rural population growth (annual %)",
    "General government final consumption expenditure (annual % growth)",
    "Consumer price index (2010 = 100)",
    "Exports of goods and services (annual % growth)",
    "Urban population growth (annual %)",
    "Population growth (annual %)",
    "Inflation, GDP deflator (annual %)",
    "Imports of goods and services (annual % growth)",
    "Final consumption expenditure (annual % growth)",
    "Unemployment, total (% of total labor force) (national estimate)",
    "Inflation, consumer prices (annual %)",
    "Gross fixed capital formation (annual % growth)",
    "Households and NPISHs Final consumption expenditure (annual % growth)",
]

In [None]:
def norm(col):
    """Min-max scale ``col`` into the range [0, 1].

    Parameters
    ----------
    col : pandas.Series (or any object with ``min``/``max`` and arithmetic)
        Values to scale. Missing values stay missing.

    Returns
    -------
    Same type as ``col``
        ``(col - min) / (max - min)``. A constant column has zero range; instead of
        the ``0 / 0 = NaN`` the plain formula yields, it is mapped to ``0.0``.
    """
    lo = col.min()
    span = col.max() - lo
    # Only scalar ranges are guarded; a non-scalar range (e.g. a DataFrame input, whose
    # min/max are per-column) keeps the plain element-wise formula.
    if np.ndim(span) == 0 and span == 0:
        return (col - lo).astype(float)
    return (col - lo) / span

In [None]:
# Add a `<indicator>_norm` column holding the min-max scaled values of each selected
# indicator (scaled over the whole table, i.e. across all countries and years).
for v in select_v_list:
    df_mlp_data[v + '_norm'] = norm(df_mlp_data[v])

In [None]:
# Preview the table including the new `*_norm` columns.
df_mlp_data.head()

In [None]:
# Independent copy used by the extraction loops below (`run`, `country_list`).
df_weo = df_mlp_data.copy()

## Load Model

In [None]:
import os

In [None]:
import torch
import flash_attn
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

In [None]:
from transformers import AutoTokenizer, AutoModel
import torch
import torchvision.transforms as T
from PIL import Image

from torchvision.transforms.functional import InterpolationMode


# Per-channel mean / std handed to `T.Normalize` in `build_transform`.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    """Return the torchvision pipeline applied to every image tile.

    The pipeline converts the image to RGB when needed, resizes it to
    ``input_size x input_size`` with bicubic interpolation, converts it to a tensor and
    normalises it with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    def _ensure_rgb(img):
        # leave images that are already RGB untouched
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the ``(cols, rows)`` grid whose aspect ratio is closest to ``aspect_ratio``.

    Candidates are scanned in the given order. On an exact tie with the current best,
    the later candidate replaces it only when the image area (``width * height``)
    exceeds half the area of that candidate's tile grid. Returns ``(1, 1)`` when
    ``target_ratios`` is empty.
    """
    chosen = (1, 1)
    smallest_gap = float('inf')
    image_area = width * height
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
        elif gap == smallest_gap and image_area > 0.5 * image_size * image_size * cols * rows:
            # tie: prefer the larger grid only for images big enough to fill it
            chosen = candidate
    return chosen


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """Resize ``image`` onto a tile grid and cut it into ``image_size`` square tiles.

    The grid ``(cols, rows)`` is chosen by ``find_closest_aspect_ratio`` among all grids
    whose tile count lies in ``[min_num, max_num]``. Tiles are returned row by row; when
    ``use_thumbnail`` is set and more than one tile was produced, a down-scaled copy of
    the whole image is appended as the last element.
    """
    src_w, src_h = image.size
    src_ratio = src_w / src_h

    # every (cols, rows) combination with an allowed number of tiles, smallest grids first
    grids = set(
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if cols * rows <= max_num and cols * rows >= min_num
    )
    grids = sorted(grids, key=lambda g: g[0] * g[1])

    grid = find_closest_aspect_ratio(src_ratio, grids, src_w, src_h, image_size)

    canvas_w = image_size * grid[0]
    canvas_h = image_size * grid[1]
    n_tiles = grid[0] * grid[1]
    tiles_per_row = canvas_w // image_size

    canvas = image.resize((canvas_w, canvas_h))
    tiles = []
    for idx in range(n_tiles):
        col = idx % tiles_per_row
        row = idx // tiles_per_row
        tiles.append(canvas.crop((
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )))
    assert len(tiles) == n_tiles

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles


def load_image(image_file, input_size=448, max_num=6):
    """Open ``image_file`` and return its tiles as one stacked tensor.

    The image is converted to RGB, tiled by ``dynamic_preprocess`` (thumbnail enabled,
    at most ``max_num`` grid tiles), and every tile is passed through
    ``build_transform(input_size)`` before stacking with ``torch.stack``.
    """
    rgb_image = Image.open(image_file).convert('RGB')
    to_tensor = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(rgb_image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([to_tensor(tile) for tile in tiles])

In [None]:
# "~" is a shell convention: neither `os.path.isdir` nor `from_pretrained` expands it,
# so the unexpanded string would not be recognised as a local checkpoint directory.
path = os.path.expanduser("~/local_model/InternVL-Chat-V1-5")

# Load the checkpoint in bfloat16 and let `device_map='auto'` place the weights;
# `.eval()` switches the module to inference mode.
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map='auto').eval()

tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

# Passed through the chat helpers; `generate_test` itself only runs a single forward
# pass and does not read these decoding options.
generation_config = dict(
    num_beams=1,
    max_new_tokens=512,
    do_sample=False,
)

In [None]:
def chat_img(self, tokenizer, pixel_values, question, generation_config,
         IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'):
    """Tokenize a chat prompt for ``question`` preceded by image placeholder tokens.

    ``self`` is the loaded model (this is a free function, not a bound method); its
    ``template`` and ``num_image_token`` attributes are read and its
    ``img_context_token_id`` attribute is set as a side effect.

    Parameters
    ----------
    tokenizer : tokenizer matching the model.
    pixel_values : tensor whose first dimension is the number of image tiles.
    question : str, the user message.
    generation_config : dict of generation options. It is NOT modified; a copy with
        ``eos_token_id`` added is returned instead.

    Returns
    -------
    (input_ids, attention_mask, generation_config)
        The two tensors are moved to the default CUDA device.
    """
    self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

    # Prefer the '<|im_end|>' token as EOS when the tokenizer knows it (id != 0).
    im_end_id = tokenizer.convert_tokens_to_ids('<|im_end|>')  # 92542, InternLM2
    eos_token_id = im_end_id if im_end_id != 0 else tokenizer.eos_token_id

    from internvl_chat.internvl.conversation import get_conv_template

    template = get_conv_template(self.template)

    image_bs = pixel_values.shape[0]
    print(f'dynamic ViT batch size: {image_bs}')

    # One context placeholder per visual token of every tile, wrapped in start/end markers.
    image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * image_bs + IMG_END_TOKEN
    question = image_tokens + '\n' + question

    template.append_message(template.roles[0], question)
    template.append_message(template.roles[1], None)
    query = template.get_prompt()
    print(query)

    model_inputs = tokenizer(query, return_tensors='pt')
    input_ids = model_inputs['input_ids'].cuda()
    attention_mask = model_inputs['attention_mask'].cuda()

    # Copy instead of writing into the caller's (shared, module-level) dict.
    generation_config = dict(generation_config, eos_token_id=eos_token_id)
    return input_ids, attention_mask, generation_config

In [None]:
def chat_no_img(self, tokenizer, question, generation_config):
    """Tokenize a text-only chat prompt for ``question``.

    ``self`` is the loaded model (this is a free function, not a bound method); only its
    ``template`` attribute is read.

    Parameters
    ----------
    tokenizer : tokenizer matching the model.
    question : str, the user message.
    generation_config : dict of generation options. It is NOT modified; a copy with
        ``eos_token_id`` added is returned instead.

    Returns
    -------
    (input_ids, attention_mask, generation_config)
        The two tensors are moved to the default CUDA device.
    """
    # Prefer the '<|im_end|>' token as EOS when the tokenizer knows it (id != 0).
    im_end_id = tokenizer.convert_tokens_to_ids('<|im_end|>')  # 92542, InternLM2
    eos_token_id = im_end_id if im_end_id != 0 else tokenizer.eos_token_id

    from internvl_chat.internvl.conversation import get_conv_template

    template = get_conv_template(self.template)

    template.append_message(template.roles[0], question)
    template.append_message(template.roles[1], None)
    query = template.get_prompt()

    model_inputs = tokenizer(query, return_tensors='pt')
    input_ids = model_inputs['input_ids'].cuda()
    attention_mask = model_inputs['attention_mask'].cuda()

    # Copy instead of writing into the caller's (shared, module-level) dict.
    generation_config = dict(generation_config, eos_token_id=eos_token_id)

    return input_ids, attention_mask, generation_config

In [None]:
def get_representation(model, tokenizer, question, generation_config, pixel_values=None):
    """Run one forward pass for ``question`` and return the last entry of the hidden states.

    The prompt is tokenized by ``chat_img`` when ``pixel_values`` is given and by
    ``chat_no_img`` otherwise; the result is fed to ``generate_test`` with
    ``output_hidden_states=True`` and ``output.hidden_states[-1]`` is returned.
    """
    if pixel_values is None:
        prepared = chat_no_img(model, tokenizer, question, generation_config)
    else:
        prepared = chat_img(model, tokenizer, pixel_values, question, generation_config)
    input_ids, attention_mask, generation_config = prepared

    forward_kwargs = dict(
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
    )
    # the generation options are forwarded as extra keyword arguments
    output = generate_test(model, **forward_kwargs, **generation_config)

    return output.hidden_states[-1]

In [None]:
from typing import Any, List, Optional, Tuple, Union

In [None]:
from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
                          LlamaTokenizer)

In [None]:
@torch.no_grad()
def generate_test(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **generate_kwargs,
) -> Any:
    """Run a single forward pass of ``self.language_model`` (no token generation).

    ``self`` is the loaded multimodal model (this is a free function). When
    ``pixel_values`` is given, the embeddings at positions whose token id equals
    ``self.img_context_token_id`` are overwritten with the visual features
    (``visual_features`` if supplied, else ``self.extract_feature(pixel_values)``).

    ``generation_config`` and ``**generate_kwargs`` are accepted for call
    compatibility but are not used.

    Returns
    -------
    Whatever ``self.language_model(...)`` returns for the given
    ``output_hidden_states`` / ``return_dict`` flags.
    """
    # Token embeddings are needed on both the text-only and the image path.
    input_embeds = self.language_model.get_input_embeddings()(input_ids)

    if pixel_values is not None:
        if visual_features is not None:
            vit_embeds = visual_features
        else:
            vit_embeds = self.extract_feature(pixel_values)

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        # Flatten the ids the same way so the mask lines up with the flattened embeddings.
        selected = (input_ids.reshape(B * N) == self.img_context_token_id)
        assert selected.sum() != 0, 'no image context token found in input_ids'
        input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

        input_embeds = input_embeds.reshape(B, N, C)

    output = self.language_model(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        use_cache=True
    )
    return output

In [None]:
# Display the working table (the notebook shows a truncated view for large frames).
df_weo

In [None]:
# Distinct country names, in order of first appearance; `run` iterates over this array.
country_list = df_weo['Country Name'].unique()
country_list

In [None]:
def get_prompt(v, value):
    """Return the text prompt for indicator ``v`` with ``value`` filled in.

    The prompt is the indicator's description from ``prompt_dict`` followed by
    ``prompt_str``. For indicators whose name contains ``'%'`` the final character of
    the template is replaced by ``' %.'`` so the value reads as a percentage.
    """
    template = prompt_dict[v] + prompt_str
    is_percentage = '%' in v
    return (template[:-1] + ' %.' if is_percentage else template).format(value=value)

In [None]:
# Quick check of the prompt builder on a '%' indicator.
get_prompt('Rural population growth (annual %)', 3)

## RUN

In [None]:
def run():
    """Extract one feature vector per (country, year, indicator) and save them per country.

    For every country in ``country_list`` and every year present for it in ``df_weo``,
    each non-missing indicator of ``select_v_list`` is turned into a text prompt and
    passed through ``get_representation``. ``[-1][-1]`` of the returned hidden states
    is kept (presumably last sequence in the batch, last token - confirm the
    (batch, seq, dim) layout), and two numbers are appended to it: the indicator's
    0-based position in ``select_v_list`` and its normalised value.

    Each year maps to ``[list_of_vectors, GDP growth]``; the per-country table is
    written to ``results_mlp_13v_norm/<country>.jsonl`` with ``DataFrame.to_json``.

    NOTE: a later cell in this notebook defines another ``run`` that shadows this one.
    """
    out_dir = 'results_mlp_13v_norm/'
    # Create the target directory up front so the first (expensive) country is not lost.
    os.makedirs(out_dir, exist_ok=True)

    for country in country_list:
        year_dict = {}
        df_country = df_weo.loc[df_weo['Country Name'] == country]

        for year in df_country['year'].unique():
            # Filter once per year instead of once per lookup.
            df_year = df_country.loc[df_country['year'] == year]
            v_tensor_list = []
            for i, v in enumerate(select_v_list):
                value = df_year[v].iloc[0]

                # pd.isna covers NaN and None alike and never raises on non-float values.
                if pd.isna(value):
                    continue

                norm_value = df_year[v + '_norm'].iloc[0]
                prompt = get_prompt(v, "%.4f" % value)

                tensor = get_representation(model, tokenizer, prompt, generation_config,
                                            pixel_values=None)[-1][-1].cpu().float().numpy()
                tensor = np.append(tensor, i)
                tensor = np.append(tensor, norm_value)
                v_tensor_list.append(tensor)

            year_dict[year] = [v_tensor_list,
                               df_year['GDP growth (annual %)'].iloc[0]]
        print(country)
        pd.DataFrame(year_dict).to_json(out_dir + country + '.jsonl')

In [None]:
# Execute the extraction loop defined in the previous cell.
run()

In [None]:
# Number of distinct country names (missing names, if any, count as one value).
df_weo['Country Name'].nunique(dropna=False)

## Optional: add nighttime light data

In [None]:
# Two-column lookup sheet read without a header row; its columns are labelled
# REF_AREA and ADM0_ISO here so it can be merged onto the light data by ADM0_ISO.
df_code = pd.read_excel('../light_data/IMF_ISO.xlsx', header=None)
df_code.columns = ['REF_AREA', 'ADM0_ISO']
df_code

In [None]:
# Load the light data, attach REF_AREA via the ADM0_ISO lookup, drop rows without a
# REF_AREA match and store the code as an integer.
df_light_data = (
    pd.read_csv('../light_data/light_data.csv', index_col=0)
    .merge(df_code, how='left', on='ADM0_ISO')
    .dropna(subset=['REF_AREA'])
    .assign(REF_AREA=lambda frame: frame['REF_AREA'].astype(int))
)
df_light_data

In [None]:
# Keep the months 2013-01 .. 2023-12 (both bounds inclusive).
in_window = df_light_data['year_month'].between(201301, 202312)
df_light_data = df_light_data.loc[in_window]

In [None]:
# Collapse rows sharing (ADM0_ISO, year_month, REF_AREA) to the maximum of each numeric column.
# The original call `.max('sum')` passed the string positionally, where pandas receives it
# as `numeric_only` (a truthy value); spelling the flag out keeps the result and states
# the intent. The extra `pd.DataFrame(...)` wrapper around the result was redundant.
df_light_data = (
    df_light_data
    .groupby(['ADM0_ISO', 'year_month', 'REF_AREA'])
    .max(numeric_only=True)
    .reset_index()
)

In [None]:
# Year = first four characters of the textual `year_month` value.
df_light_data['year'] = df_light_data['year_month'].astype(str).str[:4].astype(int)

In [None]:
# Average every remaining column per (ADM0_ISO, REF_AREA, year).
# NOTE(review): `year_month` is not a grouping key, so it is averaged as well and no longer
# holds a real month afterwards - `use_light_data` still filters on it; confirm this is intended.
df_light_data = df_light_data.groupby(['ADM0_ISO', 'REF_AREA', 'year']).mean().reset_index()
df_light_data

In [None]:
# Min-max scale the yearly `mean` light intensity over the whole table.
df_light_data['norm'] = norm(df_light_data['mean'])
df_light_data

In [None]:
# The previous cell already computed the `norm` column from `mean`; repeating the
# identical assignment here was redundant, so this cell only displays the frame.
df_light_data

In [None]:
def use_light_data(df_country, year):
    """Build the nighttime-light prompt for one country and year.

    Parameters
    ----------
    df_country : pandas.DataFrame
        Rows of a single country. The country is identified by the first value of its
        ``REF_AREA`` column when that column exists, otherwise by the first value of
        ``Country Code`` matched against ``df_light_data['ADM0_ISO']``.
    year : int
        Year whose rows (``year_month`` between ``<year>01`` and ``<year>12``) are used.

    Returns
    -------
    list of (str, float)
        ``[(prompt, norm)]`` where ``prompt`` embeds the mean of the ``sum`` and ``mean``
        columns and ``norm`` is the mean of the ``norm`` column; ``[]`` when no light
        data is available or the lookup fails (a warning is emitted in the latter case).
    """
    import warnings

    light_prompt_str = 'Nighttime light remote sensing data refers to the use of remote sensing technology to capture the distribution of lights on Earth at night. It can effectively reflect the spatial distribution of human activities and is therefore commonly used in remote sensing inversion of various socio-economic data. In this data, each pixel represents the light intensity of a geographical area of 500 meters by 500 meters. In this year, the total sum of light intensity of all pixels occupied by the country or region is {sum_}, the average is {mean_}. Based on this, what is the forecast for gross domestic product(GDP) growth in this year?'

    try:
        # NOTE(review): `df_weo` as built in this notebook has no 'REF_AREA' column, so the
        # original unconditional lookup raised KeyError, which the bare `except:` swallowed -
        # the light prompt was silently never produced. The fallback assumes 'Country Code'
        # and 'ADM0_ISO' use the same country codes - TODO confirm.
        if 'REF_AREA' in df_country.columns:
            key_column, key = 'REF_AREA', df_country['REF_AREA'].iloc[0]
        else:
            key_column, key = 'ADM0_ISO', df_country['Country Code'].iloc[0]

        df_temp = df_light_data.loc[df_light_data[key_column] == key]

        # Numeric month bounds; also works when `year` arrives as a float such as 2013.0.
        first_month = int(year) * 100 + 1
        last_month = int(year) * 100 + 12
        df_temp = df_temp.loc[(df_temp['year_month'] >= first_month) &
                              (df_temp['year_month'] <= last_month)]

        df_temp = df_temp.sort_values('year_month')
        # `numeric_only=True` is what the original positional `.max('sum')` amounted to.
        df_temp = (df_temp
                   .groupby(['ADM0_ISO', 'year_month', 'REF_AREA'])
                   .max(numeric_only=True)
                   .reset_index())
        df_temp = df_temp.dropna()
        if df_temp.empty:
            return []

        sum_ = "%.4f" % df_temp['sum'].mean()
        mean_ = "%.4f" % df_temp['mean'].mean()
        norm_ = df_temp['norm'].mean()
        prompt = light_prompt_str.format(sum_=sum_, mean_=mean_)

        return [(prompt, norm_)]

    except Exception as exc:
        # Stay best-effort (the caller just skips the light feature) but make the failure
        # visible, and no longer trap KeyboardInterrupt / SystemExit.
        warnings.warn(f'use_light_data failed for year {year}: {exc!r}')
        return []


In [None]:
def run():
    """Extract indicator and nighttime-light feature vectors and save them per country.

    For every country in ``country_list`` and every year present for it in ``df_weo``,
    each non-missing indicator of ``select_v_list`` is turned into a text prompt and
    passed through ``get_representation``. ``[-1][-1]`` of the returned hidden states
    is kept (presumably last sequence in the batch, last token - confirm the
    (batch, seq, dim) layout), and two numbers are appended to it: the indicator's
    1-based position in ``select_v_list`` and its normalised value.

    For years after 2012 the prompts returned by ``use_light_data`` are embedded the
    same way and tagged with the fixed id ``30`` plus their normalised light value.

    Each year maps to ``[list_of_vectors, GDP growth]``; the per-country table is
    written to ``results_mlp_13v_norm_light_sms/<country>.jsonl`` with ``DataFrame.to_json``.

    NOTE: this definition shadows the earlier ``run`` in this notebook.
    """
    out_dir = 'results_mlp_13v_norm_light_sms/'
    # Create the target directory up front so the first (expensive) country is not lost.
    os.makedirs(out_dir, exist_ok=True)

    for country in country_list:
        year_dict = {}
        df_country = df_weo.loc[df_weo['Country Name'] == country]

        for year in df_country['year'].unique():
            # Filter once per year instead of once per lookup.
            df_year = df_country.loc[df_country['year'] == year]
            v_tensor_list = []
            for i, v in enumerate(select_v_list):
                value = df_year[v].iloc[0]

                # pd.isna covers NaN and None alike and never raises on non-float values.
                if pd.isna(value):
                    continue

                norm_value = df_year[v + '_norm'].iloc[0]
                prompt = get_prompt(v, "%.4f" % value)

                tensor = get_representation(model, tokenizer, prompt, generation_config,
                                            pixel_values=None)[-1][-1].cpu().float().numpy()
                tensor = np.append(tensor, i + 1)
                tensor = np.append(tensor, norm_value)
                v_tensor_list.append(tensor)

            if year > 2012:
                # An empty list (no light data for this country/year) simply adds nothing.
                for light_prompt, light_norm in use_light_data(df_country, year):
                    tensor = get_representation(model, tokenizer, light_prompt, generation_config,
                                                pixel_values=None)[-1][-1].cpu().float().numpy()
                    tensor = np.append(tensor, 30)
                    tensor = np.append(tensor, light_norm)
                    v_tensor_list.append(tensor)

            year_dict[year] = [v_tensor_list,
                               df_year['GDP growth (annual %)'].iloc[0]]
        print(country)
        pd.DataFrame(year_dict).to_json(out_dir + country + '.jsonl')