In [14]:
#Set the Pyspark environment  variables
import os, findspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F


os.environ['SPARK_HOME'] = "/Users/shrutimac/documents/Apps/spark"
findspark.init(os.environ["SPARK_HOME"])
os.environ["PYSPARK_DRIVER_PYTHON"] = "jupyter"
os.environ["PYSPARK_DRIVER_PYTHON_OPTS"] = "lab"
os.environ["PYSPARK_PYTHON"] = "python"

In [15]:
spark = (
    SparkSession.builder
    .appName("TariffsTradeAvailability")
    .getOrCreate()
)

In [16]:
base_path = "/Users/shrutimac/Documents/big data/Final Project/Data Processed"

trade_path  = f"{base_path}/merged_trade_cleaned.csv"   # or .parquet
tariff_path = f"{base_path}/merged_tariff_cleaned.csv" 

df_trade = spark.read.csv(trade_path, header=True, inferSchema=True)
df_tariff = spark.read.csv(tariff_path, header=True, inferSchema=True)

                                                                                

In [17]:
df_trade.show(5)

+------------+------------+-----------+------------+-----------+-----------+----+-------------+-------------+----------------------+
|Nomenclature|ReporterISO3|ProductCode|ReporterName|PartnerISO3|PartnerName|Year|TradeFlowName|TradeFlowCode|TradeValue in 1000 USD|
+------------+------------+-----------+------------+-----------+-----------+----+-------------+-------------+----------------------+
|          H3|         BRA|       1212|      Brazil|        KOR|South Korea|2008|       Import|            5|               122.581|
|          H3|         BRA|       1212|      Brazil|        PHL|Philippines|2007|       Import|            5|               109.206|
|          H3|         BRA|       1212|      Brazil|        PHL|Philippines|2008|       Import|            5|               331.613|
|          H3|         BRA|       1212|      Brazil|        TUR|     Turkey|2007|       Import|            5|                 0.923|
|          H3|         BRA|       1213|      Brazil|        ARG|  Arg

In [18]:
df_tariff.show(5)

+--------------+------------+--------+-------------+-------+--------------------+-------+------------+-----------+----------+------------+--------+--------------+----------------+------------------+------------+------------+------------------+--------------------+-------------------------+-------------------------+----------------+
|Selected Nomen|Native Nomen|Reporter|Reporter Name|Product|        Product Name|Partner|Partner Name|Tariff Year|Trade Year|Trade Source|DutyType|Simple Average|Weighted Average|Standard Deviation|Minimum Rate|Maximum Rate|Nbr of Total Lines|Nbr of DomesticPeaks|Nbr of InternationalPeaks|Imports Value in 1000 USD|Binding Coverage|
+--------------+------------+--------+-------------+-------+--------------------+-------+------------+-----------+----------+------------+--------+--------------+----------------+------------------+------------+------------+------------------+--------------------+-------------------------+-------------------------+----------------

In [19]:
df_trade.printSchema()
df_trade.count()

root
 |-- Nomenclature: string (nullable = true)
 |-- ReporterISO3: string (nullable = true)
 |-- ProductCode: integer (nullable = true)
 |-- ReporterName: string (nullable = true)
 |-- PartnerISO3: string (nullable = true)
 |-- PartnerName: string (nullable = true)
 |-- Year: integer (nullable = true)
 |-- TradeFlowName: string (nullable = true)
 |-- TradeFlowCode: integer (nullable = true)
 |-- TradeValue in 1000 USD: double (nullable = true)



5894119

In [20]:
df_tariff.printSchema()
df_tariff.count()

root
 |-- Selected Nomen: string (nullable = true)
 |-- Native Nomen: string (nullable = true)
 |-- Reporter: integer (nullable = true)
 |-- Reporter Name: string (nullable = true)
 |-- Product: integer (nullable = true)
 |-- Product Name: string (nullable = true)
 |-- Partner: integer (nullable = true)
 |-- Partner Name: string (nullable = true)
 |-- Tariff Year: integer (nullable = true)
 |-- Trade Year: integer (nullable = true)
 |-- Trade Source: string (nullable = true)
 |-- DutyType: string (nullable = true)
 |-- Simple Average: double (nullable = true)
 |-- Weighted Average: double (nullable = true)
 |-- Standard Deviation: double (nullable = true)
 |-- Minimum Rate: double (nullable = true)
 |-- Maximum Rate: double (nullable = true)
 |-- Nbr of Total Lines: double (nullable = true)
 |-- Nbr of DomesticPeaks: double (nullable = true)
 |-- Nbr of InternationalPeaks: double (nullable = true)
 |-- Imports Value in 1000 USD: double (nullable = true)
 |-- Binding Coverage: string (n

1874056

### Basic cleaning

In [21]:
df_trade = (
    df_trade
    .withColumn("Year", F.col("Year").cast("int"))
    .withColumn("ReporterISO3", F.upper(F.col("ReporterISO3")))
    .withColumn("PartnerISO3", F.upper(F.col("PartnerISO3")))
    .withColumn("ProductCode", F.col("ProductCode").cast("string"))
    .withColumnRenamed("TradeValue in 1000 USD", "TradeValueKUSD")
)


In [22]:
df_tariff = (
    df_tariff
    .withColumn("TariffYear", F.col("Tariff Year").cast("int"))
    .withColumn("TradeYear",  F.col("Trade Year").cast("int"))
    .withColumn("ReporterISO3", F.upper(F.col("Reporter")))
    .withColumn("PartnerISO3",  F.upper(F.col("Partner")))
    .withColumn("ProductCode",  F.col("Product").cast("string"))
)


### Defining our “focus set” of countries and years

In [23]:
focus_countries = [
    "USA","CHN","IND","DEU","JPN","BRA","FRA","KOR",
    "CAN","GBR","ITA","ZAF","MEX","AUS","SGP","IDN",
    "VNM","THA","MYS","PHL","TUR","ARE","SAU","ARG",
    "CHL","NGA","EGY","RUS","NLD","ESP"
]

# Trade expected years
trade_years_expected  = list(range(2007, 2025))  # 2007–2024
tariff_years_expected = list(range(2007, 2024))  # example: 2007–2023


# High-level availability check – Trade

In [24]:
# 4.1 Overall year coverage
df_trade.select(
    F.min("Year").alias("min_year"),
    F.max("Year").alias("max_year")
).show()


+--------+--------+
|min_year|max_year|
+--------+--------+
|    2007|    2024|
+--------+--------+



                                                                                

In [25]:
# Count rows per year:
trade_year_counts = (
    df_trade.groupBy("Year")
    .count()
    .orderBy("Year")
)

trade_year_counts.show(50)


+----+------+
|Year| count|
+----+------+
|2007|283017|
|2008|284468|
|2009|346780|
|2010|351201|
|2011|354787|
|2012|356654|
|2013|356101|
|2014|316328|
|2015|315605|
|2016|315448|
|2017|320490|
|2018|322057|
|2019|324018|
|2020|319096|
|2021|322502|
|2022|324337|
|2023|321051|
|2024|360179|
+----+------+



## Missing years per reporter (trade)
Goal: for each reporter, which expected years are missing?

Create a “complete grid” of (ReporterISO3, Year).

Left-anti join with the actual data → those are the missing combos.

In [26]:
# All reporters in the filtered dataset
reporters = df_trade.select("ReporterISO3").distinct()

# Creating a DataFrame of expected years
years_df = spark.createDataFrame(
    [(y,) for y in trade_years_expected],
    ["Year"]
)

# Full expected grid (ReporterISO3, Year)
expected_trade = reporters.crossJoin(years_df)

# Actual combos present
actual_trade = df_trade.select("ReporterISO3", "Year").distinct()

missing_trade_years = (
    expected_trade
    .join(actual_trade, ["ReporterISO3", "Year"], how="left_anti")
    .orderBy("ReporterISO3", "Year")
)

missing_trade_years.show(100)

                                                                                

+------------+----+
|ReporterISO3|Year|
+------------+----+
|         IND|2007|
|         IND|2008|
+------------+----+



**Might be possible India didn't start trading by 2019 (Need to fact check this)**

## Country coverage by product & year (trade)
we want to also find a way to identify which countries data is available for a particular category of product or year

**How many countries per product-year?**

In [27]:
trade_avail_product = (
    df_trade
    .filter(F.col("PartnerISO3").isin(focus_countries))
    .groupBy("Year", "ProductCode")
    .agg(
        F.countDistinct("ReporterISO3").alias("n_reporters"),
        F.countDistinct("PartnerISO3").alias("n_partners")
    )
    .orderBy("Year", "ProductCode")
)

trade_avail_product.show(10)




+----+-----------+-----------+----------+
|Year|ProductCode|n_reporters|n_partners|
+----+-----------+-----------+----------+
|2007|       1001|          5|        28|
|2007|       1002|          3|        14|
|2007|       1003|          5|        25|
|2007|       1004|          5|        22|
|2007|       1005|          5|        27|
|2007|       1006|          5|        29|
|2007|       1007|          5|        24|
|2007|       1008|          5|        29|
|2007|       1101|          5|        27|
|2007|       1102|          5|        29|
+----+-----------+-----------+----------+
only showing top 10 rows


                                                                                

In [29]:
import pyspark.sql.functions as F

def show_trade_slice(year, product_code, limit=500):
    slice_df = (
        df_trade
        .filter(
            (F.col("Year") == year) &
            (F.col("ProductCode") == product_code)
        )
        .select(
            "Year",
            "ProductCode",
            "ReporterISO3", "ReporterName",
            "PartnerISO3", "PartnerName"
        )
        .distinct()
        .orderBy("ReporterISO3", "PartnerISO3")
        .limit(limit)
    )
    pdf = slice_df.toPandas()
    display(pdf)

# EXAMPLE:
show_trade_slice(2007, 1001)


Unnamed: 0,Year,ProductCode,ReporterISO3,ReporterName,PartnerISO3,PartnerName
0,2007,1001,BRA,Brazil,ARG,Argentina
1,2007,1001,BRA,Brazil,CAN,Canada
2,2007,1001,BRA,Brazil,FRA,France
3,2007,1001,BRA,Brazil,IND,India
4,2007,1001,BRA,Brazil,JPN,Japan
...,...,...,...,...,...,...
63,2007,1001,FRA,France,MYS,Malaysia
64,2007,1001,FRA,France,NGA,Nigeria
65,2007,1001,FRA,France,NLD,Netherlands
66,2007,1001,FRA,France,RUS,Russian Federation


## Availability checks – Tariff data

In [30]:
# Year coverage
df_tariff.select(
    F.min("TariffYear").alias("min_tariff_year"),
    F.max("TariffYear").alias("max_tariff_year")
).show()

tariff_year_counts = (
    df_tariff.groupBy("TariffYear")
    .count()
    .orderBy("TariffYear")
)

tariff_year_counts.show(50)


+---------------+---------------+
|min_tariff_year|max_tariff_year|
+---------------+---------------+
|           2007|           2022|
+---------------+---------------+

+----------+------+
|TariffYear| count|
+----------+------+
|      2007|113908|
|      2008|115294|
|      2009|114899|
|      2010|117226|
|      2011|119814|
|      2012|119322|
|      2013|119245|
|      2014|119365|
|      2015|119869|
|      2016|119250|
|      2017|121348|
|      2018|124910|
|      2019|123522|
|      2020|121258|
|      2021|125647|
|      2022| 79179|
+----------+------+



**We have tariff data available till the year 2022**

### Country coverage by product-year (tariff)

In [31]:
import pyspark.sql.functions as F

def show_tariff_slice(year, product_code, duty_type=None, limit=500):
    # build filter
    cond = (
        (F.col("Tariff Year") == year) &
        (F.col("Product") == str(product_code))   # make sure both are strings
    )
    if duty_type is not None:
        cond = cond & (F.col("DutyType") == duty_type)

    slice_df = (
        df_tariff
        .filter(cond)
        .select(
            "Tariff Year",
            "Product", "Product Name",
            "Reporter", "Reporter Name",
            "Partner", "Partner Name",
            "DutyType",
            "Simple Average", "Weighted Average",
            "Minimum Rate", "Maximum Rate",
            "Binding Coverage",
            "Imports Value in 1000 USD"
        )
        .distinct()
        .orderBy("Reporter", "Partner", "DutyType")
        .limit(limit)
    )

    pdf = slice_df.toPandas()
    display(pdf)


In [32]:
show_tariff_slice(2019, "1001")                 # all duty types

Unnamed: 0,Tariff Year,Product,Product Name,Reporter,Reporter Name,Partner,Partner Name,DutyType,Simple Average,Weighted Average,Minimum Rate,Maximum Rate,Binding Coverage,Imports Value in 1000 USD
0,2019,1001,Wheat and meslin.,76,Brazil,32,Argentina,MFN,3.33,9.99,0.00,10.00,,1239079.236
1,2019,1001,Wheat and meslin.,76,Brazil,124,Canada,MFN,10.00,10.00,10.00,10.00,,27532.349
2,2019,1001,Wheat and meslin.,76,Brazil,250,France,MFN,10.00,10.00,10.00,10.00,,830.471
3,2019,1001,Wheat and meslin.,76,Brazil,643,Russian Federation,MFN,10.00,10.00,10.00,10.00,,18350.918
4,2019,1001,Wheat and meslin.,124,Canada,32,Argentina,MFN,54.58,38.25,0.00,76.50,,3297.624
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
66,2019,1001,Wheat and meslin.,840,United States,380,Italy,MFN,2.23,1.75,1.67,2.80,,605.752
67,2019,1001,Wheat and meslin.,840,United States,484,Mexico,MFN,2.80,2.80,2.80,2.80,,8.840
68,2019,1001,Wheat and meslin.,840,United States,784,United Arab Emirates,MFN,1.67,1.67,1.67,1.67,,10.951
69,2019,1001,Wheat and meslin.,840,United States,792,Turkey,MFN,1.67,1.67,1.67,1.67,,2.386


In [33]:
show_tariff_slice(2019, "1001", duty_type="AHS")  # only a specific duty type


Unnamed: 0,Tariff Year,Product,Product Name,Reporter,Reporter Name,Partner,Partner Name,DutyType,Simple Average,Weighted Average,Minimum Rate,Maximum Rate,Binding Coverage,Imports Value in 1000 USD


# Merging the trade and tariff datasets into one
Because tariffs stop at 2022, any analysis that needs tariffs will be limited to 2007–2022.

We basically want: One row per (Reporter, Partner, Product, Year) with both Trade + Tariff info.

For this project it’s simplest to use **trade year** as the join year:

trade.Year ↔ tariff."Trade Year"

### First, standardise types and rename tariff columns:

In [62]:
import pyspark.sql.functions as F

# ---------- TRADE CLEAN ----------
trade_clean_names = (
    df_trade
    # year as int
    .withColumn("Trade_Year", F.col("Year").cast("int"))
    # product code as string
    .withColumn("Trade_ProductCode", F.col("ProductCode").cast("string"))
    # normalized reporter / partner names
    .withColumn("Trade_ReporterName_norm", F.lower(F.trim(F.col("ReporterName"))))
    .withColumn("Trade_PartnerName_norm",  F.lower(F.trim(F.col("PartnerName"))))
)

# ---------- TARIFF CLEAN ----------
tariff_clean_names = (
    df_tariff
    # year to match trade year
    .withColumn("Tariff_TradeYear", F.col("Trade Year").cast("int"))
    # product code as string
    .withColumn("Tariff_ProductCode", F.col("Product").cast("string"))
    # normalized reporter / partner names
    .withColumn("Tariff_ReporterName_norm", F.lower(F.trim(F.col("Reporter Name"))))
    .withColumn("Tariff_PartnerName_norm",  F.lower(F.trim(F.col("Partner Name"))))
)


In [63]:
tr = trade_clean_names.alias("tr")
ta = tariff_clean_names.alias("ta")

df_trade_tariff_names = (
    tr.join(
        ta,
        (
            (F.col("tr.Trade_Year") == F.col("ta.Tariff_TradeYear")) &
            (F.col("tr.Trade_ProductCode") == F.col("ta.Tariff_ProductCode")) &
            (F.col("tr.Trade_ReporterName_norm") == F.col("ta.Tariff_ReporterName_norm")) &
            (F.col("tr.Trade_PartnerName_norm")  == F.col("ta.Tariff_PartnerName_norm"))
        ),
        how="left"        # keep all trade rows; tariff columns may be null
    )
)


In [64]:
df_final = (
    df_trade_tariff_names
    .select(
        # Keys / identifiers (from trade)
        F.col("tr.Trade_Year").alias("Year"),
        F.col("tr.ReporterName").alias("ReporterName"),
        F.col("tr.PartnerName").alias("PartnerName"),
        F.col("tr.Trade_ProductCode").alias("ProductCode"),
        F.col("tr.TradeFlowName"),
        F.col("tr.TradeValueKUSD"),

        # Some tariff metadata
        F.col("ta.DutyType").alias("Tariff_DutyType"),
        F.col("ta.Tariff_TradeYear").alias("Tariff_Year"),
        F.col("ta.Trade Source").alias("Tariff_TradeSource"),

        # Tariff rates
        F.col("ta.Simple Average").alias("Tariff_SimpleAvg"),
        F.col("ta.Weighted Average").alias("Tariff_WeightedAvg"),
        F.col("ta.Minimum Rate").alias("Tariff_MinRate"),
        F.col("ta.Maximum Rate").alias("Tariff_MaxRate"),
        F.col("ta.Imports Value in 1000 USD").alias("Tariff_ImportsKUSD")
    )
)


### Sanity checks to be sure the join worked
**How many trade rows got a tariff?**

In [65]:
rows_total = df_final.count()
rows_with_tariff = df_final.filter(F.col("Tariff_SimpleAvg").isNotNull()).count()

print("Total trade rows:     ", rows_total)
print("Rows with tariff data:", rows_with_tariff)
print("Coverage %:           ", rows_with_tariff * 100.0 / rows_total)


[Stage 115:>                                                      (0 + 10) / 11]

Total trade rows:      5894119
Rows with tariff data: 2272708
Coverage %:            38.55890931282521


                                                                                

**Coverage by year (remember tariffs only go to 2022)**

In [66]:
coverage_by_year = (
    df_final
    .groupBy("Year")
    .agg(
        F.count("*").alias("n_trade_rows"),
        F.sum(
            F.when(F.col("Tariff_SimpleAvg").isNotNull(), 1).otherwise(0)
        ).alias("n_with_tariff")
    )
    .withColumn("coverage_pct", F.col("n_with_tariff") * 100.0 / F.col("n_trade_rows"))
    .orderBy("Year")
)

coverage_by_year.show(40, truncate=False)


[Stage 124:>                                                      (0 + 10) / 11]

+----+------------+-------------+------------------+
|Year|n_trade_rows|n_with_tariff|coverage_pct      |
+----+------------+-------------+------------------+
|2007|283017      |121997       |43.10589116554836 |
|2008|284468      |122680       |43.12611611850894 |
|2009|346780      |144913       |41.78816540746295 |
|2010|351201      |148605       |42.31337610086532 |
|2011|354787      |153252       |43.19549476164572 |
|2012|356654      |150551       |42.2120598675467  |
|2013|356101      |150389       |42.232119539119516|
|2014|316328      |144884       |45.80182595280848 |
|2015|315605      |145296       |46.03729345225836 |
|2016|315448      |144809       |45.90582282975324 |
|2017|320490      |149664       |46.69849293269681 |
|2018|322057      |152204       |47.2599570883415  |
|2019|324018      |153276       |47.3047793641094  |
|2020|319096      |149109       |46.72857071226214 |
|2021|322502      |152556       |47.30389268903759 |
|2022|324337      |88523        |27.2935249447

                                                                                

**Inspect a few matched rows manually**

In [68]:
df_final.filter(
    (F.col("Year") == 2009) &
    (F.col("ProductCode") == "2849")  # example
).show(50, truncate=False)




+----+--------------+--------------------+-----------+-------------+--------------+---------------+-----------+------------------+----------------+------------------+--------------+--------------+------------------+
|Year|ReporterName  |PartnerName         |ProductCode|TradeFlowName|TradeValueKUSD|Tariff_DutyType|Tariff_Year|Tariff_TradeSource|Tariff_SimpleAvg|Tariff_WeightedAvg|Tariff_MinRate|Tariff_MaxRate|Tariff_ImportsKUSD|
+----+--------------+--------------------+-----------+-------------+--------------+---------------+-----------+------------------+----------------+------------------+--------------+--------------+------------------+
|2009|Canada        |Japan               |2849       |Export       |927.805       |MFN            |2009       |WTO               |0.0             |0.0               |0.0           |0.0           |367.372           |
|2009|United States |Japan               |2849       |Export       |1993.109      |MFN            |2009       |WTO               |2.26  

                                                                                

**How many rows have missing TradeValueKUSD?**

In [69]:
missing_trade_value = df_final.filter(F.col("TradeValueKUSD").isNull()).count()
total_rows = df_final.count()

print("Missing TradeValueKUSD:", missing_trade_value)
print("Total rows:", total_rows)
print("Percent missing:", (missing_trade_value / total_rows) * 100)


[Stage 143:>                                                      (0 + 10) / 11]

Missing TradeValueKUSD: 0
Total rows: 5894119
Percent missing: 0.0


                                                                                

Which column tells us the tariff imposed during the trade?
Primary tariff rate columns:
| Column                 | Meaning                                                                                   |
| ---------------------- | ----------------------------------------------------------------------------------------- |
| **Tariff_SimpleAvg**   | *Simple average of tariff rates* for that product between reporter → partner in that year |
| **Tariff_WeightedAvg** | *Weighted average* tariff rate (weighted by import values)                                |
| **Tariff_MinRate**     | Minimum tariff applied for that HS code                                                   |
| **Tariff_MaxRate**     | Maximum tariff applied                                                                    |
| **Tariff_DutyType**    | MFN / EAP / Bound etc.                                                                    |


**Which one should we use for analysis?**

Use Tariff_WeightedAvg: This is the industry-standard measure because it represents the REAL tariff burden weighted by the importance of each tariff line.

If WeightedAvg is missing, fallback to SimpleAvg.

So your “tariff imposed” column = Tariff_WeightedAvg (primary).

**how many trades have tariff information?**

In [70]:
df_final.filter(F.col("Tariff_WeightedAvg").isNotNull()).count()

                                                                                

2272708

In [71]:
df_final.filter(F.col("Tariff_SimpleAvg").isNotNull()).count()

                                                                                

2272708

In [72]:
df_final = df_final.withColumn(
    "EffectiveTariff",
    F.coalesce(F.col("Tariff_WeightedAvg"), F.col("Tariff_SimpleAvg"))
)


Now EffectiveTariff is the single column representing the tariff rate for each trade row.

In [73]:
df_final.filter(F.col("EffectiveTariff").isNotNull()).count()

                                                                                

2272708

In [74]:
# Save as Parquet
output_path_parquet = "/Users/shrutimac/Documents/big data/Final Project/Data Processed/df_final.parquet"

df_final.write.mode("overwrite").parquet(output_path_parquet)

print("Saved Parquet:", output_path_parquet)


25/11/21 15:11:57 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
25/11/21 15:11:57 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 84.44% for 9 writers
25/11/21 15:11:57 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 76.00% for 10 writers
25/11/21 15:11:59 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 84.44% for 9 writers
25/11/21 15:11:59 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers

Saved Parquet: /Users/shrutimac/Documents/big data/Final Project/Data Processed/df_final.parquet


                                                                                

In [75]:
# Save as CSV
output_path_csv = "/Users/shrutimac/Documents/big data/Final Project/Data Processed/df_final.csv"

(df_final
    .coalesce(1)        # put into one CSV file
    .write
    .mode("overwrite")
    .option("header", True)
    .csv(output_path_csv)
)

print("Saved CSV folder:", output_path_csv)


[Stage 184:>                                                        (0 + 1) / 1]

Saved CSV folder: /Users/shrutimac/Documents/big data/Final Project/Data Processed/df_final.csv


                                                                                